You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dct_ref.cpp 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #include "test/common/dct_ref.h"
  2. namespace megdnn {
  3. namespace test {
  4. struct FixCase {
  5. std::vector<int> mask_offset;
  6. std::vector<int> mask_val;
  7. };
  8. using Param = DctChannelSelectForward::Param;
  9. static inline FixCase get_fix_mask(Param::FastImpl impl) {
  10. std::vector<int> fix_32_mask_offset{0, 16, 24, 32};
  11. std::vector<int> fix_32_mask_val{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32,
  12. 25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
  13. 3, 10, 0, 1, 8, 16, 9, 2, 3, 10};
  14. megdnn_assert(impl == Param::FastImpl::FIX_32_MASK, "only support gen FIX_32_MASK");
  15. return {fix_32_mask_offset, fix_32_mask_val};
  16. }
  17. CheckerHelper::TensorsConstriant gen_dct_constriant(
  18. const size_t /* n */, const size_t ic, const size_t ih, const size_t iw,
  19. const size_t oc, Param param) {
  20. auto constraint = [=](CheckerHelper::TensorValueArray& tensors_orig) {
  21. const size_t block = param.dct_block_size;
  22. const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1;
  23. megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0");
  24. std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make();
  25. DctTestcase& test_case = *test_case_ptr.get();
  26. UniformIntRNG rng(0, 255);
  27. UniformIntRNG mask_rng(0, 64 / block_c - 1);
  28. const size_t no_mask_oc = ic * block * block;
  29. megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block);
  30. megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block);
  31. TensorND mask_offset;
  32. TensorND mask_val;
  33. std::vector<int>& mask_offset_vec = test_case.mask_offset_vec;
  34. std::vector<int>& mask_val_vec = test_case.mask_val_vec;
  35. UniformIntRNG rng_oc(0, oc);
  36. if (param.fastImpl == Param::FastImpl::FIX_32_MASK) {
  37. auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK);
  38. mask_offset_vec = fix_32_mask.mask_offset;
  39. mask_val_vec = fix_32_mask.mask_val;
  40. megdnn_assert(oc == 32, "oc must eq 32");
  41. } else if (no_mask_oc > oc) {
  42. size_t remain_oc = oc;
  43. mask_offset_vec.resize(ic + 1);
  44. mask_val_vec.resize(oc);
  45. mask_offset_vec[0] = 0;
  46. for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) {
  47. size_t random_len = (int)rng_oc.gen_single_val() * block_c;
  48. size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0)
  49. ? remain_oc
  50. : random_len % remain_oc;
  51. megdnn_assert(
  52. mask_len % block_c == 0,
  53. "mask_len mod block_c == 0, but %zu mod %d ", mask_len,
  54. block_c);
  55. const size_t oc_idx = mask_offset_vec[ic_idx];
  56. remain_oc -= mask_len;
  57. mask_offset_vec[ic_idx + 1] = oc_idx + mask_len;
  58. for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) {
  59. mask_val_vec[oc_idx + mask_idx] = (int)mask_rng.gen_single_val();
  60. }
  61. }
  62. }
  63. mask_offset = TensorND(
  64. mask_offset_vec.data(), {{mask_offset_vec.size()}, dtype::Int32()});
  65. mask_val =
  66. TensorND(mask_val_vec.data(), {{mask_val_vec.size()}, dtype::Int32()});
  67. if (tensors_orig.size() > 1) {
  68. megdnn_assert(tensors_orig.size() == 4, "tensors_orig.size() == 4");
  69. megdnn_assert(mask_offset_vec.size() >= 2, "mask_offset_vec.size() >= 2");
  70. megdnn_assert(
  71. tensors_orig[1].layout == mask_offset.layout,
  72. "tensors_orig[1].layout == mask_offset.layout");
  73. megdnn_assert(
  74. tensors_orig[2].layout == mask_val.layout,
  75. "tensors_orig[2].layout == mask_val.layout");
  76. auto naive_handle = create_cpu_handle(2, false);
  77. megdnn_memcpy_D2D(
  78. naive_handle.get(), tensors_orig[1].raw_ptr(),
  79. mask_offset.raw_ptr(), mask_offset.layout.span().dist_byte());
  80. megdnn_memcpy_D2D(
  81. naive_handle.get(), tensors_orig[2].raw_ptr(), mask_val.raw_ptr(),
  82. mask_val.layout.span().dist_byte());
  83. }
  84. };
  85. return constraint;
  86. }
  87. std::shared_ptr<DctTestcase> gen_dct_case(
  88. const size_t n, const size_t ic, const size_t ih, const size_t iw,
  89. const size_t oc, Param param, DType dst_dtype, bool correct_result) {
  90. const size_t block = param.dct_block_size;
  91. const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1;
  92. megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0");
  93. std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make();
  94. DctTestcase& test_case = *test_case_ptr.get();
  95. UniformIntRNG rng(0, 255);
  96. UniformIntRNG mask_rng(0, 64 / block_c - 1);
  97. const size_t input_elements = n * ic * ih * iw;
  98. const size_t no_mask_oc = ic * block * block;
  99. megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block);
  100. megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block);
  101. std::vector<uint8_t>& inp_vec = test_case.inp_vec;
  102. inp_vec.resize(input_elements);
  103. TensorShape input_shape{n, ic, ih, iw};
  104. for (auto& elm : inp_vec) {
  105. elm = (uint8_t)rng.gen_single_val();
  106. }
  107. auto src = TensorND(inp_vec.data(), {input_shape, dtype::Uint8()});
  108. TensorND mask_offset;
  109. TensorND mask_val;
  110. std::vector<int>& mask_offset_vec = test_case.mask_offset_vec;
  111. std::vector<int>& mask_val_vec = test_case.mask_val_vec;
  112. UniformIntRNG rng_oc(0, oc);
  113. if (param.fastImpl == Param::FastImpl::FIX_32_MASK) {
  114. auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK);
  115. mask_offset_vec = fix_32_mask.mask_offset;
  116. mask_val_vec = fix_32_mask.mask_val;
  117. megdnn_assert(oc == 32, "oc must eq 32");
  118. } else if (no_mask_oc > oc) {
  119. size_t remain_oc = oc;
  120. mask_offset_vec.resize(ic + 1);
  121. mask_val_vec.resize(oc);
  122. mask_offset_vec[0] = 0;
  123. for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) {
  124. size_t random_len = (int)rng_oc.gen_single_val() * block_c;
  125. size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0)
  126. ? remain_oc
  127. : random_len % remain_oc;
  128. megdnn_assert(
  129. mask_len % block_c == 0,
  130. "mask_len mod block_c == 0, but %zu mod %d ", mask_len, block_c);
  131. const size_t oc_idx = mask_offset_vec[ic_idx];
  132. remain_oc -= mask_len;
  133. mask_offset_vec[ic_idx + 1] = oc_idx + mask_len;
  134. for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) {
  135. mask_val_vec[oc_idx + mask_idx] = (int)mask_rng.gen_single_val();
  136. }
  137. }
  138. }
  139. mask_offset = TensorND(
  140. mask_offset_vec.data(), {{mask_offset_vec.size()}, dtype::Int32()});
  141. mask_val = TensorND(mask_val_vec.data(), {{mask_val_vec.size()}, dtype::Int32()});
  142. if (mask_offset_vec.size() >= 2) {
  143. test_case.testcase_in = {
  144. src, mask_offset, mask_val, {nullptr, {{}, dst_dtype}}};
  145. } else {
  146. test_case.testcase_in = {src, {}, {}, {nullptr, {{}, dst_dtype}}};
  147. }
  148. auto naive_handle = create_cpu_handle(2, false);
  149. auto opr_naive = naive_handle->create_operator<DctChannelSelectForward>();
  150. opr_naive->param() = param;
  151. using Proxy = OprProxy<DctChannelSelectForward>;
  152. Proxy naive_proxy;
  153. TensorLayout temp_dst_layout;
  154. temp_dst_layout.dtype = dst_dtype;
  155. TensorLayoutArray layouts{
  156. src.layout, mask_offset.layout, mask_val.layout, temp_dst_layout};
  157. naive_proxy.deduce_layout(opr_naive.get(), layouts);
  158. const size_t output_elements = layouts[3].total_nr_elems();
  159. std::vector<float>& output_vec = test_case.output_vec;
  160. output_vec.resize(output_elements);
  161. auto dst = TensorND(output_vec.data(), layouts[3]);
  162. DctTestcase::TensorValueArray testcase_naive;
  163. testcase_naive.emplace_back(test_case.testcase_in[0]);
  164. testcase_naive.emplace_back(test_case.testcase_in[1]);
  165. testcase_naive.emplace_back(test_case.testcase_in[2]);
  166. testcase_naive.emplace_back(dst);
  167. if (correct_result) {
  168. naive_proxy.exec(opr_naive.get(), testcase_naive);
  169. }
  170. test_case.testcase_out = {{}, {}, {}, dst};
  171. return test_case_ptr;
  172. }
  173. } // namespace test
  174. } // namespace megdnn
  175. // vim: syntax=cpp.doxygen