You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dct_ref.cpp 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. /**
  2. * \file
  3. * dnn/test/common/dct_ref.cpp
  4. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  5. *
  6. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  7. *
  8. * Unless required by applicable law or agreed to in writing,
  9. * software distributed under the License is distributed on an
  10. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  11. * implied.
  12. */
  13. #include "test/common/dct_ref.h"
  14. namespace megdnn {
  15. namespace test {
  /*!
   * \brief A channel-selection mask held as two flat integer arrays.
   *
   * The generators in this file copy both arrays verbatim into a test case's
   * mask_offset_vec / mask_val_vec.
   */
  struct FixCase {
      //! Prefix offsets into mask_val. Where the generators in this file build
      //! a mask themselves, entry i is the start position of the i-th input
      //! channel's entries and the last entry is the total entry count.
      std::vector<int> mask_offset;

      //! Mask values, grouped back to back in the order given by mask_offset.
      std::vector<int> mask_val;
  };
  20. using Param = DctChannelSelectForward::Param;
  21. static inline FixCase get_fix_mask(Param::FastImpl impl) {
  22. std::vector<int> fix_32_mask_offset{0, 16, 24, 32};
  23. std::vector<int> fix_32_mask_val{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32,
  24. 25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
  25. 3, 10, 0, 1, 8, 16, 9, 2, 3, 10};
  26. megdnn_assert(impl == Param::FastImpl::FIX_32_MASK, "only support gen FIX_32_MASK");
  27. return {fix_32_mask_offset, fix_32_mask_val};
  28. }
  29. CheckerHelper::TensorsConstriant gen_dct_constriant(
  30. const size_t /* n */, const size_t ic, const size_t ih, const size_t iw,
  31. const size_t oc, Param param) {
  32. auto constraint = [=](CheckerHelper::TensorValueArray& tensors_orig) {
  33. const size_t block = param.dct_block_size;
  34. const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1;
  35. megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0");
  36. std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make();
  37. DctTestcase& test_case = *test_case_ptr.get();
  38. UniformIntRNG rng(0, 255);
  39. UniformIntRNG mask_rng(0, 64 / block_c - 1);
  40. const size_t no_mask_oc = ic * block * block;
  41. megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block);
  42. megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block);
  43. TensorND mask_offset;
  44. TensorND mask_val;
  45. std::vector<int>& mask_offset_vec = test_case.mask_offset_vec;
  46. std::vector<int>& mask_val_vec = test_case.mask_val_vec;
  47. UniformIntRNG rng_oc(0, oc);
  48. if (param.fastImpl == Param::FastImpl::FIX_32_MASK) {
  49. auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK);
  50. mask_offset_vec = fix_32_mask.mask_offset;
  51. mask_val_vec = fix_32_mask.mask_val;
  52. megdnn_assert(oc == 32, "oc must eq 32");
  53. } else if (no_mask_oc > oc) {
  54. size_t remain_oc = oc;
  55. mask_offset_vec.resize(ic + 1);
  56. mask_val_vec.resize(oc);
  57. mask_offset_vec[0] = 0;
  58. for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) {
  59. size_t random_len = (int)rng_oc.gen_single_val() * block_c;
  60. size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0)
  61. ? remain_oc
  62. : random_len % remain_oc;
  63. megdnn_assert(
  64. mask_len % block_c == 0,
  65. "mask_len mod block_c == 0, but %zu mod %d ", mask_len,
  66. block_c);
  67. const size_t oc_idx = mask_offset_vec[ic_idx];
  68. remain_oc -= mask_len;
  69. mask_offset_vec[ic_idx + 1] = oc_idx + mask_len;
  70. for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) {
  71. mask_val_vec[oc_idx + mask_idx] = (int)mask_rng.gen_single_val();
  72. }
  73. }
  74. }
  75. mask_offset = TensorND(
  76. mask_offset_vec.data(), {{mask_offset_vec.size()}, dtype::Int32()});
  77. mask_val =
  78. TensorND(mask_val_vec.data(), {{mask_val_vec.size()}, dtype::Int32()});
  79. if (tensors_orig.size() > 1) {
  80. megdnn_assert(tensors_orig.size() == 4, "tensors_orig.size() == 4");
  81. megdnn_assert(mask_offset_vec.size() >= 2, "mask_offset_vec.size() >= 2");
  82. megdnn_assert(
  83. tensors_orig[1].layout == mask_offset.layout,
  84. "tensors_orig[1].layout == mask_offset.layout");
  85. megdnn_assert(
  86. tensors_orig[2].layout == mask_val.layout,
  87. "tensors_orig[2].layout == mask_val.layout");
  88. auto naive_handle = create_cpu_handle(2, false);
  89. megdnn_memcpy_D2D(
  90. naive_handle.get(), tensors_orig[1].raw_ptr(),
  91. mask_offset.raw_ptr(), mask_offset.layout.span().dist_byte());
  92. megdnn_memcpy_D2D(
  93. naive_handle.get(), tensors_orig[2].raw_ptr(), mask_val.raw_ptr(),
  94. mask_val.layout.span().dist_byte());
  95. }
  96. };
  97. return constraint;
  98. }
  99. std::shared_ptr<DctTestcase> gen_dct_case(
  100. const size_t n, const size_t ic, const size_t ih, const size_t iw,
  101. const size_t oc, Param param, DType dst_dtype, bool correct_result) {
  102. const size_t block = param.dct_block_size;
  103. const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1;
  104. megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0");
  105. std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make();
  106. DctTestcase& test_case = *test_case_ptr.get();
  107. UniformIntRNG rng(0, 255);
  108. UniformIntRNG mask_rng(0, 64 / block_c - 1);
  109. const size_t input_elements = n * ic * ih * iw;
  110. const size_t no_mask_oc = ic * block * block;
  111. megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block);
  112. megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block);
  113. std::vector<uint8_t>& inp_vec = test_case.inp_vec;
  114. inp_vec.resize(input_elements);
  115. TensorShape input_shape{n, ic, ih, iw};
  116. for (auto& elm : inp_vec) {
  117. elm = (uint8_t)rng.gen_single_val();
  118. }
  119. auto src = TensorND(inp_vec.data(), {input_shape, dtype::Uint8()});
  120. TensorND mask_offset;
  121. TensorND mask_val;
  122. std::vector<int>& mask_offset_vec = test_case.mask_offset_vec;
  123. std::vector<int>& mask_val_vec = test_case.mask_val_vec;
  124. UniformIntRNG rng_oc(0, oc);
  125. if (param.fastImpl == Param::FastImpl::FIX_32_MASK) {
  126. auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK);
  127. mask_offset_vec = fix_32_mask.mask_offset;
  128. mask_val_vec = fix_32_mask.mask_val;
  129. megdnn_assert(oc == 32, "oc must eq 32");
  130. } else if (no_mask_oc > oc) {
  131. size_t remain_oc = oc;
  132. mask_offset_vec.resize(ic + 1);
  133. mask_val_vec.resize(oc);
  134. mask_offset_vec[0] = 0;
  135. for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) {
  136. size_t random_len = (int)rng_oc.gen_single_val() * block_c;
  137. size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0)
  138. ? remain_oc
  139. : random_len % remain_oc;
  140. megdnn_assert(
  141. mask_len % block_c == 0,
  142. "mask_len mod block_c == 0, but %zu mod %d ", mask_len, block_c);
  143. const size_t oc_idx = mask_offset_vec[ic_idx];
  144. remain_oc -= mask_len;
  145. mask_offset_vec[ic_idx + 1] = oc_idx + mask_len;
  146. for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) {
  147. mask_val_vec[oc_idx + mask_idx] = (int)mask_rng.gen_single_val();
  148. }
  149. }
  150. }
  151. mask_offset = TensorND(
  152. mask_offset_vec.data(), {{mask_offset_vec.size()}, dtype::Int32()});
  153. mask_val = TensorND(mask_val_vec.data(), {{mask_val_vec.size()}, dtype::Int32()});
  154. if (mask_offset_vec.size() >= 2) {
  155. test_case.testcase_in = {
  156. src, mask_offset, mask_val, {nullptr, {{}, dst_dtype}}};
  157. } else {
  158. test_case.testcase_in = {src, {}, {}, {nullptr, {{}, dst_dtype}}};
  159. }
  160. auto naive_handle = create_cpu_handle(2, false);
  161. auto opr_naive = naive_handle->create_operator<DctChannelSelectForward>();
  162. opr_naive->param() = param;
  163. using Proxy = OprProxy<DctChannelSelectForward>;
  164. Proxy naive_proxy;
  165. TensorLayout temp_dst_layout;
  166. temp_dst_layout.dtype = dst_dtype;
  167. TensorLayoutArray layouts{
  168. src.layout, mask_offset.layout, mask_val.layout, temp_dst_layout};
  169. naive_proxy.deduce_layout(opr_naive.get(), layouts);
  170. const size_t output_elements = layouts[3].total_nr_elems();
  171. std::vector<float>& output_vec = test_case.output_vec;
  172. output_vec.resize(output_elements);
  173. auto dst = TensorND(output_vec.data(), layouts[3]);
  174. DctTestcase::TensorValueArray testcase_naive;
  175. testcase_naive.emplace_back(test_case.testcase_in[0]);
  176. testcase_naive.emplace_back(test_case.testcase_in[1]);
  177. testcase_naive.emplace_back(test_case.testcase_in[2]);
  178. testcase_naive.emplace_back(dst);
  179. if (correct_result) {
  180. naive_proxy.exec(opr_naive.get(), testcase_naive);
  181. }
  182. test_case.testcase_out = {{}, {}, {}, dst};
  183. return test_case_ptr;
  184. }
  185. } // namespace test
  186. } // namespace megdnn
  187. // vim: syntax=cpp.doxygen