You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

_split.py 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Jun 24 11:13:26 2022
  5. @author: ljia
  6. Reference: scikit-learn.
  7. """
  8. from abc import abstractmethod
  9. import numbers
  10. import warnings
  11. import numpy as np
  12. from sklearn.utils import check_random_state, check_array, column_or_1d, indexable
  13. from sklearn.utils.validation import _num_samples
  14. from sklearn.utils.multiclass import type_of_target
  15. class BaseCrossValidatorWithValid(object):
  16. """Base class for all cross-validators.
  17. Implementations must define `_iter_valid_test_masks` or `_iter_valid_stest_indices`.
  18. """
  19. def split(self, X, y=None, groups=None):
  20. """Generate indices to split data into training, valid, and test set.
  21. Parameters
  22. ----------
  23. X : array-like of shape (n_samples, n_features)
  24. Training data, where `n_samples` is the number of samples
  25. and `n_features` is the number of features.
  26. y : array-like of shape (n_samples,)
  27. The target variable for supervised learning problems.
  28. groups : array-like of shape (n_samples,), default=None
  29. Group labels for the samples used while splitting the dataset into
  30. train/test set.
  31. Yields
  32. ------
  33. train : ndarray
  34. The training set indices for that split.
  35. valid : ndarray
  36. The valid set indices for that split.
  37. test : ndarray
  38. The testing set indices for that split.
  39. """
  40. X, y, groups = indexable(X, y, groups)
  41. indices = np.arange(_num_samples(X))
  42. for valid_index, test_index in self._iter_valid_test_masks(X, y, groups):
  43. train_index = indices[np.logical_not(np.logical_or(valid_index, test_index))]
  44. valid_index = indices[valid_index]
  45. test_index = indices[test_index]
  46. yield train_index, valid_index, test_index
  47. # Since subclasses must implement either _iter_valid_test_masks or
  48. # _iter_valid_test_indices, neither can be abstract.
  49. def _iter_valid_test_masks(self, X=None, y=None, groups=None):
  50. """Generates boolean masks corresponding to valid and test sets.
  51. By default, delegates to _iter_valid_test_indices(X, y, groups)
  52. """
  53. for valid_index, test_index in self._iter_valid_test_indices(X, y, groups):
  54. valid_mask = np.zeros(_num_samples(X), dtype=bool)
  55. test_mask = np.zeros(_num_samples(X), dtype=bool)
  56. valid_mask[valid_index] = True
  57. test_mask[test_index] = True
  58. yield valid_mask, test_mask
  59. def _iter_valid_test_indices(self, X=None, y=None, groups=None):
  60. """Generates integer indices corresponding to valid and test sets."""
  61. raise NotImplementedError
  62. @abstractmethod
  63. def get_n_splits(self, X=None, y=None, groups=None):
  64. """Returns the number of splitting iterations in the cross-validator"""
  65. def __repr__(self):
  66. return _build_repr(self)
  67. class _BaseKFoldWithValid(BaseCrossValidatorWithValid):
  68. """Base class for KFoldWithValid, GroupKFoldWithValid, and StratifiedKFoldWithValid"""
  69. @abstractmethod
  70. def __init__(self, n_splits, *, stratify, shuffle, random_state):
  71. if not isinstance(n_splits, numbers.Integral):
  72. raise ValueError(
  73. 'The number of folds must be of Integral type. '
  74. '%s of type %s was passed.' % (n_splits, type(n_splits))
  75. )
  76. n_splits = int(n_splits)
  77. if n_splits <= 2:
  78. raise ValueError(
  79. 'k-fold cross-validation requires at least one'
  80. ' train/valid/test split by setting n_splits=3 or more,'
  81. ' got n_splits={0}.'.format(n_splits)
  82. )
  83. if not isinstance(shuffle, bool):
  84. raise TypeError('shuffle must be True or False; got {0}'.format(shuffle))
  85. if not shuffle and random_state is not None: # None is the default
  86. raise ValueError(
  87. 'Setting a random_state has no effect since shuffle is '
  88. 'False. You should leave '
  89. 'random_state to its default (None), or set shuffle=True.',
  90. )
  91. self.n_splits = n_splits
  92. self.stratify = stratify
  93. self.shuffle = shuffle
  94. self.random_state = random_state
  95. def split(self, X, y=None, groups=None):
  96. """Generate indices to split data into training, valid and test set."""
  97. X, y, groups = indexable(X, y, groups)
  98. n_samples = _num_samples(X)
  99. if self.n_splits > n_samples:
  100. raise ValueError(
  101. (
  102. 'Cannot have number of splits n_splits={0} greater'
  103. ' than the number of samples: n_samples={1}.'
  104. ).format(self.n_splits, n_samples)
  105. )
  106. for train, valid, test in super().split(X, y, groups):
  107. yield train, valid, test
  108. class KFoldWithValid(_BaseKFoldWithValid):
  109. def __init__(
  110. self,
  111. n_splits=5,
  112. *,
  113. stratify=False,
  114. shuffle=False,
  115. random_state=None
  116. ):
  117. super().__init__(
  118. n_splits=n_splits,
  119. stratify=stratify,
  120. shuffle=shuffle,
  121. random_state=random_state
  122. )
  123. def _make_valid_test_folds(self, X, y=None):
  124. rng = check_random_state(self.random_state)
  125. y = np.asarray(y)
  126. type_of_target_y = type_of_target(y)
  127. allowed_target_types = ('binary', 'multiclass')
  128. if type_of_target_y not in allowed_target_types:
  129. raise ValueError(
  130. 'Supported target types are: {}. Got {!r} instead.'.format(
  131. allowed_target_types, type_of_target_y
  132. )
  133. )
  134. y = column_or_1d(y)
  135. _, y_idx, y_inv = np.unique(y, return_index=True, return_inverse=True)
  136. # y_inv encodes y according to lexicographic order. We invert y_idx to
  137. # map the classes so that they are encoded by order of appearance:
  138. # 0 represents the first label appearing in y, 1 the second, etc.
  139. _, class_perm = np.unique(y_idx, return_inverse=True)
  140. y_encoded = class_perm[y_inv]
  141. n_classes = len(y_idx)
  142. y_counts = np.bincount(y_encoded)
  143. min_groups = np.min(y_counts)
  144. if np.all(self.n_splits > y_counts):
  145. raise ValueError(
  146. "n_splits=%d cannot be greater than the"
  147. " number of members in each class." % (self.n_splits)
  148. )
  149. if self.n_splits > min_groups:
  150. warnings.warn(
  151. "The least populated class in y has only %d"
  152. " members, which is less than n_splits=%d."
  153. % (min_groups, self.n_splits),
  154. UserWarning,
  155. )
  156. # Determine the optimal number of samples from each class in each fold,
  157. # using round robin over the sorted y. (This can be done direct from
  158. # counts, but that code is unreadable.)
  159. y_order = np.sort(y_encoded)
  160. allocation = np.asarray(
  161. [
  162. np.bincount(y_order[i :: self.n_splits], minlength=n_classes)
  163. for i in range(self.n_splits)
  164. ]
  165. )
  166. # To maintain the data order dependencies as best as possible within
  167. # the stratification constraint, we assign samples from each class in
  168. # blocks (and then mess that up when shuffle=True).
  169. test_folds = np.empty(len(y), dtype='i')
  170. for k in range(n_classes):
  171. # since the kth column of allocation stores the number of samples
  172. # of class k in each test set, this generates blocks of fold
  173. # indices corresponding to the allocation for class k.
  174. folds_for_class = np.arange(self.n_splits).repeat(allocation[:, k])
  175. if self.shuffle:
  176. rng.shuffle(folds_for_class)
  177. test_folds[y_encoded == k] = folds_for_class
  178. return test_folds
  179. def _iter_valid_test_masks(self, X, y=None, groups=None):
  180. test_folds = self._make_valid_test_folds(X, y)
  181. for i in range(self.n_splits):
  182. if i + 1 < self.n_splits:
  183. j = i + 1
  184. else:
  185. j = 0
  186. yield test_folds == i, test_folds == j
  187. def split(self, X, y, groups=None):
  188. y = check_array(y, input_name='y', ensure_2d=False, dtype=None)
  189. return super().split(X, y, groups)
  190. class _RepeatedSplitsWithValid(object):
  191. def __init__(
  192. self,
  193. cv,
  194. *,
  195. n_repeats=10,
  196. random_state=None,
  197. **cvargs
  198. ):
  199. if not isinstance(n_repeats, int):
  200. raise ValueError('Number of repetitions must be of integer type.')
  201. if n_repeats <= 0:
  202. raise ValueError('Number of repetitions must be greater than 0.')
  203. self.cv = cv
  204. self.n_repeats = n_repeats
  205. self.random_state = random_state
  206. self.cvargs = cvargs
  207. def split(self, X, y=None, groups=None):
  208. n_repeats = self.n_repeats
  209. rng = check_random_state(self.random_state)
  210. for idx in range(n_repeats):
  211. cv = self.cv(random_state=rng, shuffle=True, **self.cvargs)
  212. for train_index, valid_index, test_index in cv.split(X, y, groups):
  213. yield train_index, valid_index, test_index
  214. class RepeatedKFoldWithValid(_RepeatedSplitsWithValid):
  215. def __init__(
  216. self,
  217. *,
  218. n_splits=5,
  219. n_repeats=10,
  220. stratify=False,
  221. random_state=None
  222. ):
  223. super().__init__(
  224. KFoldWithValid,
  225. n_repeats=n_repeats,
  226. stratify=stratify,
  227. random_state=random_state,
  228. n_splits=n_splits,
  229. )

A Python package for graph kernels, graph edit distances and graph pre-image problem.