You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.h 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. /**
  2. * \file dnn/test/common/rng.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include "megdnn/dtype.h"
#include "test/common/utils.h"
#include "test/common/random_state.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <random>
#include <set>
#include <vector>
  17. namespace megdnn {
  18. namespace test {
  19. class RNG {
  20. protected:
  21. class RNGxorshf;
  22. public:
  23. virtual void gen(const TensorND& tensor) = 0;
  24. virtual ~RNG() = default;
  25. };
  26. class Float16PeriodicalRNG : public RNG {
  27. public:
  28. Float16PeriodicalRNG();
  29. Float16PeriodicalRNG(size_t range);
  30. void gen(const TensorND& tensor) override;
  31. dt_float16 get_single_val();
  32. private:
  33. void gen_all_valid_float16();
  34. size_t m_offset;
  35. std::vector<dt_float16> m_sequence;
  36. };
  37. class BFloat16PeriodicalRNG : public RNG {
  38. public:
  39. BFloat16PeriodicalRNG() {
  40. size_t bits = sizeof(dt_bfloat16) * 8;
  41. size_t mantissa_bits = std::numeric_limits<dt_bfloat16>::digits - 1;
  42. size_t exponent_bits = bits - mantissa_bits - 1;
  43. for (size_t exp = 1u << (exponent_bits - 2);
  44. exp < (1u << exponent_bits) - (1u << (exponent_bits - 2)); ++exp) {
  45. for (size_t x = 0; x < 1u << mantissa_bits; ++x) {
  46. size_t pos_num = (exp << mantissa_bits) + x;
  47. size_t neg_num =
  48. (1u << (bits - 1)) + (exp << mantissa_bits) + x;
  49. union U {
  50. U() {}
  51. uint16_t i;
  52. dt_bfloat16 f;
  53. } i2f;
  54. i2f.i = static_cast<uint16_t>(pos_num);
  55. m_sequence.push_back(i2f.f);
  56. i2f.i = static_cast<uint16_t>(neg_num);
  57. m_sequence.push_back(i2f.f);
  58. }
  59. }
  60. std::shuffle(m_sequence.begin(), m_sequence.end(),
  61. RandomState::generator());
  62. }
  63. void gen(const TensorND& tensor) override {
  64. megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait<dt_bfloat16>::enumv);
  65. size_t nr_elems = tensor.layout.span().dist_elem();
  66. auto offset = tensor.layout.span().low_elem;
  67. for (size_t i = 0; i < nr_elems; ++i) {
  68. tensor.ptr<dt_bfloat16>()[offset + i] = get_single_val();
  69. }
  70. }
  71. dt_bfloat16 get_single_val() {
  72. if (m_offset >= m_sequence.size()) {
  73. m_offset = 0;
  74. }
  75. return m_sequence[m_offset++];
  76. }
  77. private:
  78. size_t m_offset = 0;
  79. std::vector<dt_bfloat16> m_sequence;
  80. };
  81. class IIDRNG : public RNG {
  82. public:
  83. void gen(const TensorND& tensor) override;
  84. virtual dt_float32 gen_single_val() = 0;
  85. virtual bool output_is_float() { return true; }
  86. protected:
  87. virtual bool has_fast_float32();
  88. virtual void fill_fast_float32(dt_float32* dest, size_t size);
  89. };
  90. class NormalRNG final : public IIDRNG {
  91. public:
  92. NormalRNG(dt_float32 mean = 0.0f, dt_float32 stddev = 1.0f)
  93. : m_dist(mean, stddev) {}
  94. void fill_fast_float32(dt_float32* dest, size_t size) override;
  95. protected:
  96. dt_float32 gen_single_val() override;
  97. private:
  98. std::normal_distribution<dt_float32> m_dist;
  99. bool has_fast_float32() override;
  100. };
  101. class ConstValue final : public IIDRNG {
  102. public:
  103. ConstValue(dt_float32 value = 0.0f) : value_(value) {}
  104. void fill_fast_float32(dt_float32* dest, size_t size) override;
  105. protected:
  106. dt_float32 gen_single_val() override { return value_; }
  107. private:
  108. dt_float32 value_;
  109. bool has_fast_float32() override { return true; }
  110. };
  111. class UniformIntRNG : public IIDRNG {
  112. public:
  113. UniformIntRNG(dt_int32 a, dt_int32 b) : m_dist(a, b) {}
  114. dt_float32 gen_single_val() override;
  115. bool output_is_float() override { return false; }
  116. protected:
  117. std::uniform_int_distribution<dt_int32> m_dist;
  118. };
  119. //! range must be positive; each value would be negated with prob 0.5
  120. class UniformIntNonZeroRNG : public UniformIntRNG {
  121. std::uniform_int_distribution<dt_int32> m_dist_flip{0, 1};
  122. public:
  123. UniformIntNonZeroRNG(int a, int b) : UniformIntRNG(a, b) {
  124. megdnn_assert(a > 0 && b > a);
  125. }
  126. dt_float32 gen_single_val() override;
  127. };
  128. class UniformFloatRNG : public IIDRNG {
  129. public:
  130. UniformFloatRNG(dt_float32 a, dt_float32 b) : m_dist(a, b) {}
  131. dt_float32 gen_single_val() override;
  132. protected:
  133. std::uniform_real_distribution<dt_float32> m_dist;
  134. bool has_fast_float32() override;
  135. void fill_fast_float32(dt_float32* dest, size_t size) override;
  136. };
  137. //! range must be positive; each value would be negated with prob 0.5
  138. class UniformFloatNonZeroRNG : public UniformFloatRNG {
  139. std::uniform_int_distribution<dt_int32> m_dist_flip{0, 1};
  140. public:
  141. UniformFloatNonZeroRNG(float a, float b) : UniformFloatRNG(a, b) {
  142. megdnn_assert(a > 0 && b > a);
  143. }
  144. dt_float32 gen_single_val() override;
  145. void fill_fast_float32(dt_float32* dest, size_t size) override;
  146. };
  147. class UniformFloatWithZeroRNG final : public UniformFloatRNG {
  148. public:
  149. UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b,
  150. float zero_val_proportion)
  151. : UniformFloatRNG(a, b) {
  152. if (zero_val_proportion < 0.f)
  153. zero_val_proportion_ = 0.f;
  154. else if (zero_val_proportion > 1.f)
  155. zero_val_proportion_ = 1.f;
  156. else
  157. zero_val_proportion_ = zero_val_proportion;
  158. }
  159. private:
  160. float zero_val_proportion_;
  161. void fill_fast_float32(dt_float32* dest, size_t size) override;
  162. };
  163. class BernoulliRNG final : public IIDRNG {
  164. public:
  165. BernoulliRNG(dt_float32 probability_);
  166. dt_float32 gen_single_val() override;
  167. private:
  168. dt_float32 m_probability;
  169. std::uniform_real_distribution<dt_float32> m_dist;
  170. };
  171. /**
  172. * \brief RNG without replacement, so that no two values in the tensor are
  173. * equal.
  174. *
  175. * Each value is generated repeatedly by IIDRNG, until the newly-generated value
  176. * differs from any previous value.
  177. */
  178. class NoReplacementRNG final : public RNG {
  179. private:
  180. IIDRNG* m_iid_rng;
  181. public:
  182. NoReplacementRNG(IIDRNG* iid_rng) : m_iid_rng(iid_rng) {}
  183. void gen(const TensorND& tensor) override;
  184. };
  185. //! generate a batch of matrices that are likely to have a small condition num
  186. class InvertibleMatrixRNG final : public RNG {
  187. std::unique_ptr<RNGxorshf> m_rng;
  188. public:
  189. InvertibleMatrixRNG();
  190. ~InvertibleMatrixRNG() noexcept;
  191. void gen(const TensorND& tensor) override;
  192. private:
  193. template <typename ctype>
  194. void do_gen(ctype* ptr, size_t batch, size_t n);
  195. };
  196. //! generate a continuous number of delta, start from value
  197. class ConsecutiveRNG final : public IIDRNG {
  198. public:
  199. ConsecutiveRNG(dt_float32 value = 0.0f, dt_float32 delta = 1.0f)
  200. : value_(value), delta_(delta) {}
  201. void fill_fast_float32(dt_float32* dest, size_t size) override;
  202. protected:
  203. dt_float32 gen_single_val() override {
  204. auto res = value_;
  205. value_ += delta_;
  206. return res;
  207. }
  208. private:
  209. dt_float32 value_, delta_;
  210. bool has_fast_float32() override { return true; }
  211. };
  212. } // namespace test
  213. } // namespace megdnn
  214. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。