1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
//! Validation via repeated random shuffling
//! of the data and splitting into a training and test set.
//!
//! # Examples
//!
//! ```
//! use rustlearn::prelude::*;
//! use rustlearn::datasets::iris;
//! use rustlearn::cross_validation::ShuffleSplit;
//!
//!
//! let (X, y) = iris::load_data();
//!
//! let num_splits = 10;
//! let test_percentage = 0.2;
//!
//! for (train_idx, test_idx) in ShuffleSplit::new(X.rows(), num_splits, test_percentage) {
//!
//!     let X_train = X.get_rows(&train_idx);
//!     let y_train = y.get_rows(&train_idx);
//!     let X_test = X.get_rows(&test_idx);
//!     let y_test = y.get_rows(&test_idx);
//!
//!     // Model fitting happens here
//! }
//! ```

use std::iter::Iterator;

use rand;
use rand::Rng;


/// Iterator that repeatedly shuffles the indices `0..n` and splits each
/// shuffled sequence into a train part and a test part.
///
/// Construct it with `ShuffleSplit::new`; every iteration yields a
/// `(train_indices, test_indices)` pair of `Vec<usize>`.
pub struct ShuffleSplit {
    /// Number of samples; the indices `0..n` are reshuffled for every split.
    n: usize,
    /// Total number of splits the iterator yields.
    n_iter: usize,
    /// Fraction of the samples assigned to the test set.
    test_size: f32,
    /// Random number generator used to shuffle the indices.
    rng: rand::StdRng,
    /// Progress counter compared against `n_iter` to decide when to stop.
    iter: usize,
}


impl ShuffleSplit {
    /// Build a shuffle-split generator over the row indices `0..n_samples`.
    ///
    /// Iterating over the returned value produces `n_iter` pairs of index
    /// vectors: a train set of roughly `(1.0 - test_size) * n_samples`
    /// indices and a test set holding the remaining
    /// `test_size * n_samples` indices.
    ///
    /// The generator starts out with a freshly created `rand::StdRng`;
    /// call `set_rng` to substitute a seeded one for reproducible splits.
    ///
    /// # Panics
    ///
    /// Panics if `rand::StdRng::new()` returns an error.
    pub fn new(n_samples: usize, n_iter: usize, test_size: f32) -> ShuffleSplit {
        let default_rng = rand::StdRng::new().unwrap();

        ShuffleSplit {
            n: n_samples,
            n_iter: n_iter,
            test_size: test_size,
            rng: default_rng,
            iter: 0,
        }
    }

    /// Replace the random number generator used for shuffling.
    ///
    /// Two instances given identically seeded generators produce the
    /// same sequence of splits.
    pub fn set_rng(&mut self, rng: rand::StdRng) {
        self.rng = rng;
    }

    /// Return the indices `0..n` in a freshly shuffled order, advancing
    /// the internal random number generator.
    fn get_shuffled_indices(&mut self) -> Vec<usize> {
        let mut permutation: Vec<usize> = (0..self.n).collect();
        self.rng.shuffle(&mut permutation);
        permutation
    }
}


impl Iterator for ShuffleSplit {
    type Item = (Vec<usize>, Vec<usize>);

    /// Yield the next `(train, test)` pair of index vectors, or `None`
    /// once `n_iter` splits have been produced.
    ///
    /// The first `floor(n * (1.0 - test_size))` shuffled indices (capped at
    /// `n`) form the train set; the remaining indices form the test set.
    fn next(&mut self) -> Option<(Vec<usize>, Vec<usize>)> {
        // Count only splits that are actually produced, so polling an
        // exhausted iterator never advances (or overflows) the counter.
        if self.iter >= self.n_iter {
            return None;
        }
        self.iter += 1;

        // Kept in f32 on purpose: widening to f64 would change the floor
        // result for common fractions such as 0.2.
        let raw_split = (self.n as f32 * (1.0 - self.test_size)).floor() as usize;

        // A negative `test_size`, or `n as f32` rounding up for very large
        // `n`, can push the split point past the end; cap it so `split_at`
        // cannot panic and the train set simply receives every index.
        let split_idx = if raw_split > self.n {
            self.n
        } else {
            raw_split
        };

        let shuffled_indices = self.get_shuffled_indices();
        let (train, test) = shuffled_indices.split_at(split_idx);

        Some((train.to_vec(), test.to_vec()))
    }

    /// Report the exact number of splits still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.n_iter.saturating_sub(self.iter);
        (remaining, Some(remaining))
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    extern crate rand;

    use rand::{SeedableRng, StdRng};

    /// The iterator yields exactly `n_iter` splits.
    #[test]
    fn iteration() {
        let split = ShuffleSplit::new(100, 4, 0.2);

        assert_eq!(split.count(), 4);
    }


    /// Train and test sets have the sizes implied by `test_size`.
    #[test]
    fn size_split() {
        let split = ShuffleSplit::new(100, 4, 0.2);

        for (train, test) in split {
            assert_eq!(train.len(), 80);
            assert_eq!(test.len(), 20);
        }
    }


    /// Every index appears exactly once across the train and test sets.
    #[test]
    fn split_covers_every_index_once() {
        for (train, test) in ShuffleSplit::new(50, 3, 0.3) {
            let mut all = train.clone();
            all.extend(test.iter().cloned());
            all.sort();

            assert_eq!(all, (0..50).collect::<Vec<usize>>());
        }
    }


    /// Two independently constructed instances shuffle differently.
    ///
    /// Asserted directly rather than through `#[should_panic]`, which
    /// would also accept any unrelated panic as a pass.
    #[test]
    fn shuffle_differs() {
        let set1 = ShuffleSplit::new(1000, 1, 0.2).collect::<Vec<_>>();
        let set2 = ShuffleSplit::new(1000, 1, 0.2).collect::<Vec<_>>();

        assert!(set1[0].0 != set2[0].0);
    }


    /// Identically seeded generators produce identical splits.
    #[test]
    fn set_rng() {

        let seed: &[_] = &[1, 2, 3, 4];
        let rng1: StdRng = SeedableRng::from_seed(seed);
        let rng2: StdRng = SeedableRng::from_seed(seed);

        let mut split1 = ShuffleSplit::new(1000, 1, 0.2);
        let mut split2 = ShuffleSplit::new(1000, 1, 0.2);

        split1.set_rng(rng1);
        split2.set_rng(rng2);

        let set1 = split1.collect::<Vec<_>>();
        let set2 = split2.collect::<Vec<_>>();

        assert_eq!(set1[0].0, set2[0].0);
    }

}