You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.h 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. #pragma once
  2. #include "megdnn/dtype.h"
  3. #include <random>
  4. #include <set>
  5. #include "test/common/random_state.h"
  6. #include "test/common/utils.h"
  7. namespace megdnn {
  8. namespace test {
  //! COMPAT_RANDOM(begin, end): shuffle the range [begin, end) in place.
  //!
  //! Before C++17 this forwards to std::random_shuffle. From C++17 on (where
  //! std::random_shuffle is no longer available) it expands to a braced block
  //! that runs std::shuffle with a freshly default-constructed
  //! std::default_random_engine, so every expansion starts from the engine's
  //! default seed.
  #if __cplusplus < 201703L
  #define COMPAT_RANDOM(begin, end) std::random_shuffle(begin, end);
  #else
  #define COMPAT_RANDOM(begin, end)              \
      {                                          \
          std::default_random_engine rng_engine; \
          std::shuffle(begin, end, rng_engine);  \
      }
  #endif
  18. class RNG {
  19. protected:
  20. class RNGxorshf;
  21. public:
  22. virtual void gen(const TensorND& tensor) = 0;
  23. virtual ~RNG() = default;
  24. };
  25. class Float16PeriodicalRNG : public RNG {
  26. public:
  27. Float16PeriodicalRNG();
  28. Float16PeriodicalRNG(size_t range);
  29. void gen(const TensorND& tensor) override;
  30. dt_float16 get_single_val();
  31. private:
  32. void gen_all_valid_float16();
  33. size_t m_offset;
  34. std::vector<dt_float16> m_sequence;
  35. };
  36. class BFloat16PeriodicalRNG : public RNG {
  37. public:
  38. BFloat16PeriodicalRNG() {
  39. size_t bits = sizeof(dt_bfloat16) * 8;
  40. size_t mantissa_bits = std::numeric_limits<dt_bfloat16>::digits - 1;
  41. size_t exponent_bits = bits - mantissa_bits - 1;
  42. for (size_t exp = 1u << (exponent_bits - 2);
  43. exp < (1u << exponent_bits) - (1u << (exponent_bits - 2)); ++exp) {
  44. for (size_t x = 0; x < 1u << mantissa_bits; ++x) {
  45. size_t pos_num = (exp << mantissa_bits) + x;
  46. size_t neg_num = (1u << (bits - 1)) + (exp << mantissa_bits) + x;
  47. union U {
  48. U() {}
  49. uint16_t i;
  50. dt_bfloat16 f;
  51. } i2f;
  52. i2f.i = static_cast<uint16_t>(pos_num);
  53. m_sequence.push_back(i2f.f);
  54. i2f.i = static_cast<uint16_t>(neg_num);
  55. m_sequence.push_back(i2f.f);
  56. }
  57. }
  58. std::shuffle(m_sequence.begin(), m_sequence.end(), RandomState::generator());
  59. }
  60. void gen(const TensorND& tensor) override {
  61. megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait<dt_bfloat16>::enumv);
  62. size_t nr_elems = tensor.layout.span().dist_elem();
  63. auto offset = tensor.layout.span().low_elem;
  64. for (size_t i = 0; i < nr_elems; ++i) {
  65. tensor.ptr<dt_bfloat16>()[offset + i] = get_single_val();
  66. }
  67. }
  68. dt_bfloat16 get_single_val() {
  69. if (m_offset >= m_sequence.size()) {
  70. m_offset = 0;
  71. }
  72. return m_sequence[m_offset++];
  73. }
  74. private:
  75. size_t m_offset = 0;
  76. std::vector<dt_bfloat16> m_sequence;
  77. };
  78. class IIDRNG : public RNG {
  79. public:
  80. void gen(const TensorND& tensor) override;
  81. virtual dt_float32 gen_single_val() = 0;
  82. virtual bool output_is_float() { return true; }
  83. protected:
  84. virtual bool has_fast_float32();
  85. virtual void fill_fast_float32(dt_float32* dest, size_t size);
  86. };
  87. class NormalRNG final : public IIDRNG {
  88. public:
  89. NormalRNG(dt_float32 mean = 0.0f, dt_float32 stddev = 1.0f)
  90. : m_dist(mean, stddev) {}
  91. void fill_fast_float32(dt_float32* dest, size_t size) override;
  92. protected:
  93. dt_float32 gen_single_val() override;
  94. private:
  95. std::normal_distribution<dt_float32> m_dist;
  96. bool has_fast_float32() override;
  97. };
  98. class ConstValue final : public IIDRNG {
  99. public:
  100. ConstValue(dt_float32 value = 0.0f) : value_(value) {}
  101. void fill_fast_float32(dt_float32* dest, size_t size) override;
  102. protected:
  103. dt_float32 gen_single_val() override { return value_; }
  104. private:
  105. dt_float32 value_;
  106. bool has_fast_float32() override { return true; }
  107. };
  108. class UniformIntRNG : public IIDRNG {
  109. public:
  110. UniformIntRNG(dt_int32 a, dt_int32 b) : m_dist(a, b) {}
  111. dt_float32 gen_single_val() override;
  112. bool output_is_float() override { return false; }
  113. protected:
  114. std::uniform_int_distribution<dt_int32> m_dist;
  115. };
  116. //! range must be positive; each value would be negated with prob 0.5
  117. class UniformIntNonZeroRNG : public UniformIntRNG {
  118. std::uniform_int_distribution<dt_int32> m_dist_flip{0, 1};
  119. public:
  120. UniformIntNonZeroRNG(int a, int b) : UniformIntRNG(a, b) {
  121. megdnn_assert(a > 0 && b > a);
  122. }
  123. dt_float32 gen_single_val() override;
  124. };
  125. class UniformFloatRNG : public IIDRNG {
  126. public:
  127. UniformFloatRNG(dt_float32 a, dt_float32 b) : m_dist(a, b) {}
  128. dt_float32 gen_single_val() override;
  129. protected:
  130. std::uniform_real_distribution<dt_float32> m_dist;
  131. bool has_fast_float32() override;
  132. void fill_fast_float32(dt_float32* dest, size_t size) override;
  133. };
  134. //! range must be positive; each value would be negated with prob 0.5
  135. class UniformFloatNonZeroRNG : public UniformFloatRNG {
  136. std::uniform_int_distribution<dt_int32> m_dist_flip{0, 1};
  137. public:
  138. UniformFloatNonZeroRNG(float a, float b) : UniformFloatRNG(a, b) {
  139. megdnn_assert(a > 0 && b > a);
  140. }
  141. dt_float32 gen_single_val() override;
  142. void fill_fast_float32(dt_float32* dest, size_t size) override;
  143. };
  144. class UniformFloatWithValueRNG : public UniformFloatRNG {
  145. public:
  146. UniformFloatWithValueRNG(
  147. dt_float32 a, dt_float32 b, float val_proportion, float val)
  148. : UniformFloatRNG(a, b), val_(val) {
  149. if (val_proportion < 0.f)
  150. val_proportion_ = 0.f;
  151. else if (val_proportion > 1.f)
  152. val_proportion_ = 1.f;
  153. else
  154. val_proportion_ = val_proportion;
  155. }
  156. private:
  157. float val_proportion_, val_;
  158. void fill_fast_float32(dt_float32* dest, size_t size) override;
  159. };
  160. class UniformFloatWithZeroRNG final : public UniformFloatWithValueRNG {
  161. public:
  162. UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, float zero_val_proportion)
  163. : UniformFloatWithValueRNG(a, b, zero_val_proportion, 0.f) {}
  164. };
  165. class BernoulliRNG final : public IIDRNG {
  166. public:
  167. BernoulliRNG(dt_float32 probability_);
  168. dt_float32 gen_single_val() override;
  169. private:
  170. dt_float32 m_probability;
  171. std::uniform_real_distribution<dt_float32> m_dist;
  172. };
  173. /**
  174. * \brief RNG without replacement, so that no two values in the tensor are
  175. * equal.
  176. *
  177. * Each value is generated repeatedly by IIDRNG, until the newly-generated value
  178. * differs from any previous value.
  179. */
  180. class NoReplacementRNG final : public RNG {
  181. private:
  182. IIDRNG* m_iid_rng;
  183. public:
  184. NoReplacementRNG(IIDRNG* iid_rng) : m_iid_rng(iid_rng) {}
  185. void gen(const TensorND& tensor) override;
  186. };
  187. //! generate a batch of matrices that are likely to have a small condition num
  188. class InvertibleMatrixRNG final : public RNG {
  189. std::unique_ptr<RNGxorshf> m_rng;
  190. public:
  191. InvertibleMatrixRNG();
  192. ~InvertibleMatrixRNG() noexcept;
  193. void gen(const TensorND& tensor) override;
  194. private:
  195. template <typename ctype>
  196. void do_gen(ctype* ptr, size_t batch, size_t n);
  197. };
  198. //! generate a continuous number of delta, start from value
  199. class ConsecutiveRNG final : public IIDRNG {
  200. public:
  201. ConsecutiveRNG(dt_float32 value = 0.0f, dt_float32 delta = 1.0f)
  202. : value_(value), delta_(delta) {}
  203. void fill_fast_float32(dt_float32* dest, size_t size) override;
  204. protected:
  205. dt_float32 gen_single_val() override {
  206. auto res = value_;
  207. value_ += delta_;
  208. return res;
  209. }
  210. private:
  211. dt_float32 value_, delta_;
  212. bool has_fast_float32() override { return true; }
  213. };
  214. } // namespace test
  215. } // namespace megdnn
  216. // vim: syntax=cpp.doxygen