#include "test/common/rng.h"
#include "test/common/random_state.h"
#include "test/common/tensor.h"

#include <gtest/gtest.h>

#include <algorithm>
#include <cstring>
#include <limits>
#include <set>
#include <vector>

using namespace megdnn;
using namespace test;

/*!
 * \brief xorshift+ RNG, which is very fast
 *
 * see https://en.wikipedia.org/wiki/Xorshift#xorshift.2B
 */
class RNG::RNGxorshf {
    uint64_t s[2];

public:
    using result_type = uint64_t;

#ifdef WIN32
    static uint64_t min() { return 0; }
    static uint64_t max() { return std::numeric_limits<uint64_t>::max(); }
#else
    static constexpr uint64_t min() { return 0; }
    static constexpr uint64_t max() { return std::numeric_limits<uint64_t>::max(); }
#endif

    template <typename T>
    explicit RNGxorshf(T&& gen) {
        s[0] = gen();
        s[1] = gen();
    }

    uint64_t operator()() {
        uint64_t x = s[0];
        uint64_t const y = s[1];
        s[0] = y;
        x ^= x << 23;                          // a
        s[1] = x ^ y ^ (x >> 17) ^ (y >> 26);  // b, c
        return s[1] + y;
    }
};
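
// A minimal usage sketch (not part of the original file): RNGxorshf exposes
// the result_type/min()/max()/operator() interface that <random>
// distributions expect, so it can drive them directly, e.g.
//
//     RNGxorshf fast_gen{RandomState::generator()};  // seeded from the global MT19937
//     std::normal_distribution<float> dist(0.f, 1.f);
//     float v = dist(fast_gen);
//
// This is exactly how the fill_fast_float32() implementations below use it.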

Float16PeriodicalRNG::Float16PeriodicalRNG() : m_offset(0) {
    for (size_t x = 0; x < (1u << 16); ++x) {
        size_t exponent = (x >> 10) & 0x1F;
        if (exponent == 0x1F) {
            // +inf, -inf, NaN
            continue;
        }
        union U {
            U() {}
            uint16_t i;
            dt_float16 f;
        } i2f;
        i2f.i = static_cast<uint16_t>(x);
        m_sequence.push_back(i2f.f);
    }
    COMPAT_RANDOM(m_sequence.begin(), m_sequence.end());
}
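
//! The range-limited constructor below keeps the \p range smallest positive
//! float16 bit patterns (excluding +0) together with their negative-sign
//! counterparts, then shuffles the resulting sequence.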
Float16PeriodicalRNG::Float16PeriodicalRNG(size_t range) : m_offset(0) {
    union U {
        U() {}
        uint16_t i;
        dt_float16 f;
    } i2f;
    size_t x = 0;
    i2f.i = static_cast<uint16_t>(x);
    for (size_t i = 0; i < range; i++) {
        x += 1;
        i2f.i = static_cast<uint16_t>(x);
        m_sequence.push_back(i2f.f);
    }
    x = 1u << 15;
    i2f.i = static_cast<uint16_t>(x);
    for (size_t i = 0; i < range; i++) {
        x += 1;
        i2f.i = static_cast<uint16_t>(x);
        m_sequence.push_back(i2f.f);
    }
    COMPAT_RANDOM(m_sequence.begin(), m_sequence.end());
}

void Float16PeriodicalRNG::gen(const TensorND& tensor) {
    megdnn_assert(tensor.layout.dtype == dtype::Float16());
    size_t nr_elems = tensor.layout.span().dist_elem();
    auto offset = tensor.layout.span().low_elem;
    for (size_t i = 0; i < nr_elems; ++i) {
        tensor.ptr<dt_float16>()[offset + i] = get_single_val();
    }
}

dt_float16 Float16PeriodicalRNG::get_single_val() {
    if (m_offset >= m_sequence.size()) {
        m_offset = 0;
    }
    return m_sequence[m_offset++];
}
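
//! IIDRNG::gen() below dispatches in stages: a contiguous-Float32 fast path
//! via fill_fast_float32(), then the plain computing dtypes, then the
//! quantized dtypes (optionally quantizing a float value), then the 4-bit
//! types that need two values packed per byte, and finally the Byte/Uint16
//! special cases.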
void IIDRNG::gen(const TensorND& tensor) {
    if (tensor.layout.dtype == dtype::Float32() && has_fast_float32() &&
        tensor.layout.is_physical_contiguous()) {
        fill_fast_float32(tensor.ptr<dt_float32>(), tensor.layout.total_nr_elems());
        return;
    }

    auto offset = tensor.layout.span().low_elem;
    auto nr_elems = tensor.layout.span().dist_elem();
#define cb(DType)                                                       \
    if (tensor.layout.dtype == DType()) {                               \
        using ctype = typename DTypeTrait<DType>::ctype;                \
        auto ptr = tensor.ptr<ctype>();                                 \
        for (size_t i = 0; i < nr_elems; ++i) {                        \
            ptr[offset + i] = static_cast<ctype>(gen_single_val());     \
        }                                                               \
        return;                                                         \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
#undef cb
#define cb(DType)                                                              \
    if (tensor.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {             \
        using ctype = typename DTypeTrait<DType>::ctype;                       \
        auto ptr = tensor.ptr<ctype>();                                        \
        if (output_is_float()) {                                               \
            for (size_t i = 0; i < nr_elems; ++i) {                            \
                ptr[offset + i] = tensor.layout.dtype.param<DType>().quantize( \
                        static_cast<float>(gen_single_val()));                 \
            }                                                                  \
        } else {                                                               \
            for (size_t i = 0; i < nr_elems; ++i) {                            \
                ptr[offset + i] = static_cast<ctype>(gen_single_val());        \
            }                                                                  \
        }                                                                      \
        return;                                                                \
    }
    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
    //! In order to avoid an unnecessary increase in binary size, we just
    //! use QuantizedS16 dtype in winograd_filter_preprocess now.
    cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1)
#undef cb
    if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        auto ptr = static_cast<uint8_t*>(tensor.raw_ptr());
        if (output_is_float()) {
            for (size_t i = 0; i < nr_elems; i += 2) {
                uint8_t val0 = tensor.layout.dtype.param<dt_quint4>()
                                       .quantize(static_cast<float>(gen_single_val()))
                                       .as_uint8();
                uint8_t val1 = tensor.layout.dtype.param<dt_quint4>()
                                       .quantize(static_cast<float>(gen_single_val()))
                                       .as_uint8();
                ptr[(offset + i) / 2] = (val1 << 4) | val0;
            }
        } else {
            for (size_t i = 0; i < nr_elems; i += 2) {
                uint8_t val0 = static_cast<uint8_t>(gen_single_val());
                uint8_t val1 = static_cast<uint8_t>(gen_single_val());
                ptr[(offset + i) / 2] = (val1 << 4) | val0;
            }
        }
        return;
    }
    if (tensor.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
        auto ptr = static_cast<int8_t*>(tensor.raw_ptr());
        if (output_is_float()) {
            for (size_t i = 0; i < nr_elems; i += 2) {
                int8_t val0 = tensor.layout.dtype.param<dt_qint4>()
                                      .quantize(static_cast<float>(gen_single_val()))
                                      .as_int8();
                int8_t val1 = tensor.layout.dtype.param<dt_qint4>()
                                      .quantize(static_cast<float>(gen_single_val()))
                                      .as_int8();
                ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
            }
        } else {
            for (size_t i = 0; i < nr_elems; i += 2) {
                int8_t val0 = static_cast<int8_t>(gen_single_val());
                int8_t val1 = static_cast<int8_t>(gen_single_val());
                val0 = std::min(val0, DTypeTrait<dtype::QuantizedS4>::max());
                val0 = std::max(val0, DTypeTrait<dtype::QuantizedS4>::min());
                val1 = std::min(val1, DTypeTrait<dtype::QuantizedS4>::max());
                val1 = std::max(val1, DTypeTrait<dtype::QuantizedS4>::min());
                ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
            }
        }
        return;
    }
    if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) {
        memset(tensor.raw_ptr(), 0, tensor.layout.access_bytes());
        return;
    }
    if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) {
        return;
    }
    megdnn_assert(
            0, "IIDRNG does not know how to generate value for DType %s",
            tensor.layout.dtype.name());
}
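
//! Fast-path contract: subclasses that can fill a contiguous float32 buffer
//! more cheaply than repeated gen_single_val() calls override the two
//! methods below; the base class opts out, so its fill_fast_float32() must
//! never be reached.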
bool IIDRNG::has_fast_float32() {
    return false;
}

void IIDRNG::fill_fast_float32(dt_float32*, size_t) {
    megdnn_assert(0);
}
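
//! NormalRNG keeps gen_single_val() on the shared MT19937 from RandomState,
//! but its bulk path below swaps in the much faster RNGxorshf (seeded from
//! that same global generator) to drive the normal distribution.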
dt_float32 NormalRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return m_dist(gen);
}

bool NormalRNG::has_fast_float32() {
    return true;
}

void NormalRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    for (size_t i = 0; i < size; ++i) {
        dest[i] = m_dist(gen);
    }
}

void ConstValue::fill_fast_float32(dt_float32* dest, size_t size) {
    for (size_t i = 0; i < size; ++i)
        dest[i] = value_;
}

dt_float32 UniformIntRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return static_cast<dt_float32>(m_dist(gen));
}

dt_float32 UniformIntNonZeroRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    auto ret = UniformIntRNG::gen_single_val();
    if (m_dist_flip(gen)) {
        ret = -ret;
    }
    megdnn_assert(ret != 0);
    return ret;
}

dt_float32 UniformFloatRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return m_dist(gen);
}

bool UniformFloatRNG::has_fast_float32() {
    return true;
}
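
// fill_fast_float32() below maps a raw 64-bit draw r uniformly onto [a, b)
// with a single affine transform: with k = (b - a) / 2^64 and offset
// a - min * k, the value r * k + offset lies in [a, b) for every
// r in [min, max]. (min is 0 here, so the offset reduces to a.)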
void UniformFloatRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    auto k = double(m_dist.b() - m_dist.a()) /
             double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
    auto b = m_dist.a() - RNGxorshf::min() * k;
    for (size_t i = 0; i < size; ++i) {
        dest[i] = gen() * k + b;
    }
}

dt_float32 UniformFloatNonZeroRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    auto ret = UniformFloatRNG::gen_single_val();
    if (m_dist_flip(gen)) {
        ret = -ret;
    }
    megdnn_assert(ret != 0);
    return ret;
}

void UniformFloatNonZeroRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    UniformFloatRNG::fill_fast_float32(dest, size);
    for (size_t i = 0; i < size; ++i) {
        if (m_dist_flip(gen)) {
            dest[i] = -dest[i];
        }
    }
}
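
//! UniformFloatWithValueRNG below emits the fixed value val_ with
//! probability val_proportion_ and an ordinary uniform sample otherwise:
//! the first draw (scaled to [0, 1)) decides the case, and a second draw
//! produces the uniform sample when needed.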
void UniformFloatWithValueRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    auto k = double(m_dist.b() - m_dist.a()) /
             double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
    auto b = m_dist.a() - RNGxorshf::min() * k;
    auto p = 1.0 / double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
    auto pb = 0.f - RNGxorshf::min() * p;
    for (size_t i = 0; i < size; ++i) {
        float rnd = gen() * p + pb;
        if (rnd < val_proportion_) {
            dest[i] = val_;
        } else {
            dest[i] = gen() * k + b;
        }
    }
}

BernoulliRNG::BernoulliRNG(float probability_) : m_dist(0, 1) {
    megdnn_assert(0.0f <= probability_ && probability_ < 1.0f);
    m_probability = probability_;
}

dt_float32 BernoulliRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return m_dist(gen) < m_probability ? 1.0 : 0.0;
}
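
//! NoReplacementRNG::gen() below draws from the wrapped IIDRNG and uses an
//! std::set to reject duplicates, so every element of the tensor is
//! distinct. The underlying RNG must be able to produce at least as many
//! distinct values as there are elements, or the rejection loop will not
//! terminate.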
void NoReplacementRNG::gen(const TensorND& tensor) {
    auto offset = tensor.layout.span().low_elem;
    auto nr_elems = tensor.layout.span().dist_elem();
#define cb(DType)                                                      \
    if (tensor.layout.dtype == DType()) {                              \
        using ctype = typename DTypeTrait<DType>::ctype;               \
        std::set<ctype> values;                                        \
        auto ptr = tensor.ptr<ctype>();                                \
        for (size_t i = 0; i < nr_elems; ++i) {                        \
            ctype val;                                                 \
            do {                                                       \
                val = static_cast<ctype>(m_iid_rng->gen_single_val()); \
            } while (!values.insert(val).second);                      \
            ptr[offset + i] = val;                                     \
        }                                                              \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
#undef cb
}
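
//! InvertibleMatrixRNG fills each n x n matrix with uniform values in
//! [-0.5, 1.5), then adds a value in [3, 4) to one entry per row such that
//! the boosted entries occupy distinct columns (a random permutation). Up
//! to a column permutation this concentrates weight on the diagonal,
//! keeping the matrix far from singular.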
InvertibleMatrixRNG::InvertibleMatrixRNG()
        : m_rng{new RNGxorshf{RandomState::generator()}} {}

InvertibleMatrixRNG::~InvertibleMatrixRNG() noexcept = default;

template <typename ctype>
void InvertibleMatrixRNG::do_gen(ctype* ptr, size_t batch, size_t n) {
    auto&& gen = *m_rng;
    std::vector<size_t> perm(n);
    for (size_t i = 0; i < n; ++i) {
        perm[i] = i;
    }
    for (size_t i = 0; i < batch; ++i, ptr += n * n) {
        for (size_t j = 0; j < n; ++j) {
            for (size_t k = 0; k < n; ++k) {
                ptr[j * n + k] =
                        static_cast<ctype>(gen() / (RNGxorshf::max() + 1.0) * 2 - 0.5);
            }
        }
        for (size_t i = 0; i < n; ++i) {
            auto idx = gen() % (n - i) + i;
            ptr[i * n + perm[idx]] +=
                    static_cast<ctype>(gen() / (RNGxorshf::max() + 1.0) + 3);
            std::swap(perm[idx], perm[i]);
        }
    }
}

void InvertibleMatrixRNG::gen(const TensorND& tensor) {
#define cb(DType)                                                                   \
    if (tensor.layout.dtype == DType()) {                                           \
        using ctype = typename DTypeTrait<DType>::ctype;                            \
        auto ptr = tensor.ptr<ctype>();                                             \
        megdnn_assert(                                                              \
                tensor.layout.ndim >= 2 && tensor.layout.is_physical_contiguous()); \
        size_t batch = 1;                                                           \
        for (size_t i = 0; i < tensor.layout.ndim - 2; ++i) {                       \
            batch *= tensor.layout[i];                                              \
        }                                                                           \
        size_t n = tensor.layout[tensor.layout.ndim - 1];                           \
        megdnn_assert(n == tensor.layout[tensor.layout.ndim - 2]);                  \
        do_gen<ctype>(ptr, batch, n);                                               \
        return;                                                                     \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
}
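
//! ConsecutiveRNG fills the buffer with the arithmetic sequence
//! value_, value_ + delta_, value_ + 2 * delta_, ...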
void ConsecutiveRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    for (size_t i = 0; i < size; ++i)
        dest[i] = value_ + i * delta_;
}
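
// The test below exercises NoReplacementRNG end to end: each of the TIMES
// rounds draws N values without replacement from UniformIntRNG(0, N - 1),
// whose support has exactly N values, so after sorting the result must be
// exactly 0, 1, ..., N - 1.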
TEST(RNG, NO_REPLACEMENT_RNG) {
    static const size_t N = 10, TIMES = 100;
    UniformIntRNG base_rng(0, N - 1);
    NoReplacementRNG rng(&base_rng);
    auto handle = create_cpu_handle(2, false);
    for (size_t t = 0; t < TIMES; ++t) {
        TensorLayout layout({N}, dtype::Float32());
        Tensor<> tensor(handle.get(), layout);
        rng.gen(tensor.tensornd());
        std::vector<float> vals;
        for (size_t i = 0; i < N; ++i)
            vals.push_back(tensor.ptr()[i]);
        std::sort(vals.begin(), vals.end());
        for (size_t i = 0; i < N; ++i)
            ASSERT_EQ(static_cast<float>(i), vals[i]);
    }
}

// vim: syntax=cpp.doxygen