#include "test/common/rng.h"

#include <gtest/gtest.h>

#include <algorithm>
#include <limits>
#include <set>
#include <utility>
#include <vector>

#include "test/common/random_state.h"
#include "test/common/tensor.h"

using namespace megdnn;
using namespace test;
/*!
 * \brief xorshift+ RNG, which is very fast
 *
 * see https://en.wikipedia.org/wiki/Xorshift#xorshift.2B
 */
class RNG::RNGxorshf {
    uint64_t s[2];

public:
    using result_type = uint64_t;

#ifdef WIN32
    static uint64_t min() { return 0; }
    static uint64_t max() { return std::numeric_limits<uint64_t>::max(); }
#else
    static constexpr uint64_t min() { return 0; }
    static constexpr uint64_t max() { return std::numeric_limits<uint64_t>::max(); }
#endif

    //! seed the 128-bit state from another random number generator
    template <typename T>
    explicit RNGxorshf(T&& gen) {
        s[0] = gen();
        s[1] = gen();
    }

    uint64_t operator()() {
        uint64_t x = s[0];
        uint64_t const y = s[1];
        s[0] = y;
        x ^= x << 23;                          // a
        s[1] = x ^ y ^ (x >> 17) ^ (y >> 26);  // b, c
        return s[1] + y;
    }
};
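// RNGxorshf provides result_type, min(), max() and operator(), so it models
// the UniformRandomBitGenerator requirements well enough to be plugged
// directly into <random> distributions, e.g. (usage sketch):
//
//     RNGxorshf gen{RandomState::generator()};
//     std::normal_distribution<float> dist;
//     float v = dist(gen);
//
// NormalRNG::fill_fast_float32 below relies on exactly this property.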
Float16PeriodicalRNG::Float16PeriodicalRNG() : m_offset(0) {
    for (size_t x = 0; x < (1u << 16); ++x) {
        size_t exponent = (x >> 10) & 0x1F;
        if (exponent == 0x1F) {
            // +inf, -inf, NaN
            continue;
        }
        union U {
            U() {}
            uint16_t i;
            dt_float16 f;
        } i2f;
        i2f.i = static_cast<uint16_t>(x);
        m_sequence.push_back(i2f.f);
    }
    COMPAT_RANDOM(m_sequence.begin(), m_sequence.end());
}
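// IEEE 754 binary16 layout: 1 sign bit, 5 exponent bits (bits 10..14) and
// 10 mantissa bits. An all-ones exponent (0x1F) encodes +/-inf and NaN,
// which is why those bit patterns are skipped above. For example, the
// pattern 0x3C00 (exponent 0x0F, mantissa 0) decodes to 1.0.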
Float16PeriodicalRNG::Float16PeriodicalRNG(size_t range) : m_offset(0) {
    union U {
        U() {}
        uint16_t i;
        dt_float16 f;
    } i2f;
    size_t x = 0;
    i2f.i = static_cast<uint16_t>(x);
    // the first `range` bit patterns above +0 (small positive values)
    for (size_t i = 0; i < range; i++) {
        x += 1;
        i2f.i = static_cast<uint16_t>(x);
        m_sequence.push_back(i2f.f);
    }
    // the first `range` bit patterns above -0 (1u << 15 is the sign bit)
    x = 1u << 15;
    i2f.i = static_cast<uint16_t>(x);
    for (size_t i = 0; i < range; i++) {
        x += 1;
        i2f.i = static_cast<uint16_t>(x);
        m_sequence.push_back(i2f.f);
    }
    COMPAT_RANDOM(m_sequence.begin(), m_sequence.end());
}
void Float16PeriodicalRNG::gen(const TensorND& tensor) {
    megdnn_assert(tensor.layout.dtype == dtype::Float16());
    size_t nr_elems = tensor.layout.span().dist_elem();
    auto offset = tensor.layout.span().low_elem;
    for (size_t i = 0; i < nr_elems; ++i) {
        tensor.ptr<dt_float16>()[offset + i] = get_single_val();
    }
}

dt_float16 Float16PeriodicalRNG::get_single_val() {
    if (m_offset >= m_sequence.size()) {
        m_offset = 0;
    }
    return m_sequence[m_offset++];
}
void IIDRNG::gen(const TensorND& tensor) {
    if (tensor.layout.dtype == dtype::Float32() && has_fast_float32() &&
        tensor.layout.is_physical_contiguous()) {
        fill_fast_float32(tensor.ptr<dt_float32>(), tensor.layout.total_nr_elems());
        return;
    }

    auto offset = tensor.layout.span().low_elem;
    auto nr_elems = tensor.layout.span().dist_elem();

#define cb(DType)                                                   \
    if (tensor.layout.dtype == DType()) {                           \
        using ctype = typename DTypeTrait<DType>::ctype;            \
        auto ptr = tensor.ptr<ctype>();                             \
        for (size_t i = 0; i < nr_elems; ++i) {                     \
            ptr[offset + i] = static_cast<ctype>(gen_single_val()); \
        }                                                           \
        return;                                                     \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
#undef cb
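// MEGDNN_FOREACH_COMPUTING_DTYPE(cb) stamps out the block above once per
// computing dtype; e.g. cb(dtype::Float32) expands to (roughly):
//     if (tensor.layout.dtype == dtype::Float32()) {
//         using ctype = typename DTypeTrait<dtype::Float32>::ctype;
//         ... fill the tensor and return ...
//     }
// so the first matching dtype fills the tensor and exits gen().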
#define cb(DType)                                                              \
    if (tensor.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {             \
        using ctype = typename DTypeTrait<DType>::ctype;                       \
        auto ptr = tensor.ptr<ctype>();                                        \
        if (output_is_float()) {                                               \
            for (size_t i = 0; i < nr_elems; ++i) {                            \
                ptr[offset + i] = tensor.layout.dtype.param<DType>().quantize( \
                        static_cast<float>(gen_single_val()));                 \
            }                                                                  \
        } else {                                                               \
            for (size_t i = 0; i < nr_elems; ++i) {                            \
                ptr[offset + i] = static_cast<ctype>(gen_single_val());        \
            }                                                                  \
        }                                                                      \
        return;                                                                \
    }
    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
    //! To avoid an unnecessary increase in binary size, we only use the
    //! QuantizedS16 dtype in winograd_filter_preprocess for now, so it is
    //! handled here explicitly rather than through the foreach macro.
    cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1)
#undef cb
    if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        auto ptr = static_cast<uint8_t*>(tensor.raw_ptr());
        if (output_is_float()) {
            for (size_t i = 0; i < nr_elems; i += 2) {
                uint8_t val0 = tensor.layout.dtype.param<dt_quint4>()
                                       .quantize(static_cast<float>(gen_single_val()))
                                       .as_uint8();
                uint8_t val1 = tensor.layout.dtype.param<dt_quint4>()
                                       .quantize(static_cast<float>(gen_single_val()))
                                       .as_uint8();
                ptr[(offset + i) / 2] = (val1 << 4) | val0;
            }
        } else {
            for (size_t i = 0; i < nr_elems; i += 2) {
                uint8_t val0 = static_cast<uint8_t>(gen_single_val());
                uint8_t val1 = static_cast<uint8_t>(gen_single_val());
                ptr[(offset + i) / 2] = (val1 << 4) | val0;
            }
        }
        return;
    }
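    // 4-bit dtypes pack two elements per byte: element i goes to the low
    // nibble and element i + 1 to the high nibble. E.g. val0 = 3, val1 = 5
    // yields (5 << 4) | 3 == 0x53.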
    if (tensor.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
        auto ptr = static_cast<int8_t*>(tensor.raw_ptr());
        if (output_is_float()) {
            for (size_t i = 0; i < nr_elems; i += 2) {
                int8_t val0 = tensor.layout.dtype.param<dt_qint4>()
                                      .quantize(static_cast<float>(gen_single_val()))
                                      .as_int8();
                int8_t val1 = tensor.layout.dtype.param<dt_qint4>()
                                      .quantize(static_cast<float>(gen_single_val()))
                                      .as_int8();
                ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
            }
        } else {
            for (size_t i = 0; i < nr_elems; i += 2) {
                int8_t val0 = static_cast<int8_t>(gen_single_val());
                int8_t val1 = static_cast<int8_t>(gen_single_val());
                val0 = std::min(val0, DTypeTrait<dtype::QuantizedS4>::max());
                val0 = std::max(val0, DTypeTrait<dtype::QuantizedS4>::min());
                val1 = std::min(val1, DTypeTrait<dtype::QuantizedS4>::max());
                val1 = std::max(val1, DTypeTrait<dtype::QuantizedS4>::min());
                ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
            }
        }
        return;
    }
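    // For the signed 4-bit case, `val0 & 0xF` keeps only the low nibble of
    // the (possibly negative) int8 value, so unpacking code presumably
    // sign-extends the nibbles; the min/max clamping above keeps generated
    // values within the representable range of QuantizedS4.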
    if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) {
        memset(tensor.raw_ptr(), 0, tensor.layout.access_bytes());
        return;
    }
    // no values are generated for Uint16 and Bool tensors; they are left as-is
    if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) {
        return;
    }
    if (tensor.layout.dtype.enumv() == DTypeEnum::Bool) {
        return;
    }
    megdnn_assert(
            0, "IIDRNG does not know how to generate value for DType %s",
            tensor.layout.dtype.name());
}

bool IIDRNG::has_fast_float32() {
    return false;
}

void IIDRNG::fill_fast_float32(dt_float32*, size_t) {
    megdnn_assert(0);
}
dt_float32 NormalRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return m_dist(gen);
}

bool NormalRNG::has_fast_float32() {
    return true;
}

void NormalRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    for (size_t i = 0; i < size; ++i) {
        dest[i] = m_dist(gen);
    }
}

void ConstValue::fill_fast_float32(dt_float32* dest, size_t size) {
    for (size_t i = 0; i < size; ++i)
        dest[i] = value_;
}
dt_float32 UniformIntRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return static_cast<dt_float32>(m_dist(gen));
}

dt_float32 UniformIntNonZeroRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    auto ret = UniformIntRNG::gen_single_val();
    if (m_dist_flip(gen)) {
        ret = -ret;
    }
    megdnn_assert(ret != 0);
    return ret;
}

dt_float32 UniformFloatRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return m_dist(gen);
}

bool UniformFloatRNG::has_fast_float32() {
    return true;
}

void UniformFloatRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    auto k = double(m_dist.b() - m_dist.a()) /
             double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
    auto b = m_dist.a() - RNGxorshf::min() * k;
    for (size_t i = 0; i < size; ++i) {
        dest[i] = gen() * k + b;
    }
}
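// The loop above maps the raw 64-bit output r in [min(), max()] affinely
// onto the distribution's interval [a, b): the local `k` is the scale
// (b - a) / (max - min + 1) and the local `b` is the offset a - min * k,
// so r == min() yields a and r == max() yields a value just below b.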
dt_float32 UniformFloatNonZeroRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    auto ret = UniformFloatRNG::gen_single_val();
    if (m_dist_flip(gen)) {
        ret = -ret;
    }
    megdnn_assert(ret != 0);
    return ret;
}

void UniformFloatNonZeroRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    UniformFloatRNG::fill_fast_float32(dest, size);
    for (size_t i = 0; i < size; ++i) {
        if (m_dist_flip(gen)) {
            dest[i] = -dest[i];
        }
    }
}
void UniformFloatWithValueRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    RNGxorshf gen{RandomState::generator()};
    auto k = double(m_dist.b() - m_dist.a()) /
             double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
    auto b = m_dist.a() - RNGxorshf::min() * k;
    auto p = 1.0 / double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
    auto pb = 0.f - RNGxorshf::min() * p;
    for (size_t i = 0; i < size; ++i) {
        float rnd = gen() * p + pb;
        if (rnd < val_proportion_) {
            dest[i] = val_;
        } else {
            dest[i] = gen() * k + b;
        }
    }
}
BernoulliRNG::BernoulliRNG(float probability_) : m_dist(0, 1) {
    megdnn_assert(0.0f <= probability_ && probability_ < 1.0f);
    m_probability = probability_;
}

dt_float32 BernoulliRNG::gen_single_val() {
    auto&& gen = RandomState::generator();
    return m_dist(gen) < m_probability ? 1.0 : 0.0;
}
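// With m_dist uniform on [0, 1), the comparison above yields 1.0 with
// probability m_probability and 0.0 otherwise, i.e. a Bernoulli(p) sample
// encoded as float.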
void NoReplacementRNG::gen(const TensorND& tensor) {
    auto offset = tensor.layout.span().low_elem;
    auto nr_elems = tensor.layout.span().dist_elem();

#define cb(DType)                                                      \
    if (tensor.layout.dtype == DType()) {                              \
        using ctype = typename DTypeTrait<DType>::ctype;               \
        std::set<ctype> values;                                        \
        auto ptr = tensor.ptr<ctype>();                                \
        for (size_t i = 0; i < nr_elems; ++i) {                        \
            ctype val;                                                 \
            do {                                                       \
                val = static_cast<ctype>(m_iid_rng->gen_single_val()); \
            } while (!values.insert(val).second);                      \
            ptr[offset + i] = val;                                     \
        }                                                              \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
#undef cb
}
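// NoReplacementRNG uses rejection sampling: std::set::insert reports a
// duplicate via its bool result, so a value is redrawn until it is new.
// The underlying RNG must therefore be able to produce at least nr_elems
// distinct values, or the do/while loop would never terminate.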
InvertibleMatrixRNG::InvertibleMatrixRNG()
        : m_rng{new RNGxorshf{RandomState::generator()}} {}

InvertibleMatrixRNG::~InvertibleMatrixRNG() noexcept = default;

template <typename ctype>
void InvertibleMatrixRNG::do_gen(ctype* ptr, size_t batch, size_t n) {
    auto&& gen = *m_rng;
    std::vector<size_t> perm(n);
    for (size_t i = 0; i < n; ++i) {
        perm[i] = i;
    }
    for (size_t i = 0; i < batch; ++i, ptr += n * n) {
        // fill the whole n x n matrix with small entries in [-0.5, 1.5)
        for (size_t j = 0; j < n; ++j) {
            for (size_t k = 0; k < n; ++k) {
                ptr[j * n + k] = static_cast<ctype>(
                        gen() / (RNGxorshf::max() + 1.0) * 2 - 0.5);
            }
        }
        // boost one distinct column per row; `r` is used instead of reusing
        // `i`, which would shadow the batch index of the outer loop
        for (size_t r = 0; r < n; ++r) {
            auto idx = gen() % (n - r) + r;
            ptr[r * n + perm[idx]] +=
                    static_cast<ctype>(gen() / (RNGxorshf::max() + 1.0) + 3);
            std::swap(perm[idx], perm[r]);
        }
    }
}
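// The construction above is heuristic: each row receives a boost of at least
// 3 in a column drawn without repetition (a partial Fisher-Yates shuffle of
// `perm`), so the matrix has one dominant entry per row and per column on
// top of small background noise, which keeps it invertible in practice.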
void InvertibleMatrixRNG::gen(const TensorND& tensor) {
#define cb(DType)                                                             \
    if (tensor.layout.dtype == DType()) {                                     \
        using ctype = typename DTypeTrait<DType>::ctype;                      \
        auto ptr = tensor.ptr<ctype>();                                       \
        megdnn_assert(                                                        \
                tensor.layout.ndim >= 2 &&                                    \
                tensor.layout.is_physical_contiguous());                      \
        size_t batch = 1;                                                     \
        for (size_t i = 0; i < tensor.layout.ndim - 2; ++i) {                 \
            batch *= tensor.layout[i];                                        \
        }                                                                     \
        size_t n = tensor.layout[tensor.layout.ndim - 1];                     \
        megdnn_assert(n == tensor.layout[tensor.layout.ndim - 2]);            \
        do_gen<ctype>(ptr, batch, n);                                         \
        return;                                                               \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
}
void ConsecutiveRNG::fill_fast_float32(dt_float32* dest, size_t size) {
    for (size_t i = 0; i < size; ++i)
        dest[i] = value_ + i * delta_;
}
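// Usage sketch with a hypothetical TensorND `t` of dtype Float32: IIDRNG
// subclasses are driven through gen(), which dispatches to the
// fill_fast_float32 override for physically contiguous Float32 tensors
// whenever has_fast_float32() holds:
//     ConsecutiveRNG rng /* ctor args per test/common/rng.h */;
//     rng.gen(t);  // fills t with value_, value_ + delta_, ...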
TEST(RNG, NO_REPLACEMENT_RNG) {
    static const size_t N = 10, TIMES = 100;
    UniformIntRNG base_rng(0, N - 1);
    NoReplacementRNG rng(&base_rng);
    auto handle = create_cpu_handle(2, false);
    for (size_t t = 0; t < TIMES; ++t) {
        TensorLayout layout({N}, dtype::Float32());
        Tensor<> tensor(handle.get(), layout);
        rng.gen(tensor.tensornd());
        std::vector<float> vals;
        for (size_t i = 0; i < N; ++i)
            vals.push_back(tensor.ptr()[i]);
        std::sort(vals.begin(), vals.end());
        // N draws without replacement from {0, ..., N - 1} must cover every
        // value exactly once, so the sorted values are exactly 0, 1, ..., N - 1
        for (size_t i = 0; i < N; ++i)
            ASSERT_EQ(static_cast<float>(i), vals[i]);
    }
}

// vim: syntax=cpp.doxygen