You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

region_restricted_convolution.cpp 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
#include "test/naive/fixture.h"

#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/convolution.h"
// #include "test/common/regin_restricted_convolution.h"
#include "test/common/extra_impl_helper.h"
#include "test/common/random_state.h"

#include <cstdlib>
#include <vector>
  8. using namespace megdnn;
  9. using namespace test;
  10. namespace {
  11. template <typename rtype>
  12. void mask_tensor_kernel(
  13. const TensorND& in, TensorND& out, const TensorND& mask,
  14. const int32_t mask_val) {
  15. megdnn_assert(
  16. in.layout.ndim == out.layout.ndim && in.layout.ndim == 4 &&
  17. mask.layout.ndim == 3);
  18. megdnn_assert_eq_layout(in.layout, out.layout);
  19. megdnn_assert(
  20. mask.layout[0] == in.layout[0] && mask.layout[1] == in.layout[2] &&
  21. mask.layout[2] == in.layout[3]);
  22. rtype* mask_ptr = mask.compatible_ptr<rtype>();
  23. float* src_ptr = in.compatible_ptr<float>();
  24. float* dst_ptr = out.compatible_ptr<float>();
  25. for (size_t n = 0; n < in.layout[0]; ++n) {
  26. for (size_t c = 0; c < in.layout[1]; ++c) {
  27. for (size_t h = 0; h < in.layout[2]; ++h) {
  28. for (size_t w = 0; w < in.layout[3]; ++w) {
  29. size_t mask_off = n * mask.layout.stride[0] +
  30. h * mask.layout.stride[1] +
  31. w * mask.layout.stride[2];
  32. size_t src_dst_off =
  33. n * in.layout.stride[0] + c * in.layout.stride[1] +
  34. h * in.layout.stride[2] + w * in.layout.stride[3];
  35. if (mask_ptr[mask_off] == mask_val) {
  36. dst_ptr[src_dst_off] = src_ptr[src_dst_off];
  37. } else {
  38. dst_ptr[src_dst_off] = 0.;
  39. }
  40. }
  41. }
  42. }
  43. }
  44. }
  45. void mask_tensor(
  46. const TensorND& in, TensorND& out, const TensorND& mask,
  47. const int32_t mask_val) {
  48. if (mask.layout.dtype == dtype::Int32()) {
  49. mask_tensor_kernel<dt_int32>(in, out, mask, mask_val);
  50. } else if (mask.layout.dtype == dtype::Uint8()) {
  51. mask_tensor_kernel<dt_uint8>(in, out, mask, mask_val);
  52. }
  53. }
  54. } // namespace
  55. TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
  56. Checker<RegionRestrictedConvolution> checker(handle());
  57. RegionRestrictedConvolution::Param param;
  58. constexpr int N = 3;
  59. UniformIntRNG rng{0, N - 1};
  60. auto extra_impl = [&, this](const TensorNDArray& tensors) {
  61. auto conv = handle()->create_operator<Convolution>();
  62. conv->param() = param;
  63. auto workspace_size = conv->get_workspace_in_bytes(
  64. tensors[0].layout, tensors[1].layout, tensors[4].layout, nullptr);
  65. dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
  66. Workspace workspace{workspace_ptr, workspace_size};
  67. TensorND masked_src(
  68. malloc(tensors[0].layout.span().dist_byte()), tensors[0].layout);
  69. TensorNDArray dst_tensors;
  70. for (int i = 0; i < N; ++i) {
  71. dst_tensors.emplace_back(
  72. malloc(tensors[4].layout.span().dist_byte()), tensors[4].layout);
  73. }
  74. for (int i = 0; i < N; ++i) {
  75. mask_tensor(tensors[0], masked_src, tensors[2], i);
  76. conv->exec(masked_src, tensors[1], dst_tensors[i], nullptr, workspace);
  77. mask_tensor(dst_tensors[i], dst_tensors[i], tensors[3], i);
  78. }
  79. free(workspace_ptr);
  80. using Mode = ElemwiseForward::Param::Mode;
  81. auto add = handle()->create_operator<ElemwiseForward>();
  82. add->param().mode = Mode::ADD;
  83. add->exec({dst_tensors[0], dst_tensors[1]}, tensors[4]);
  84. for (int i = 2; i < N; ++i) {
  85. add->exec({dst_tensors[i], tensors[4]}, tensors[4]);
  86. }
  87. };
  88. checker.set_extra_opr_impl(extra_impl)
  89. .set_rng(2, &rng)
  90. .set_rng(3, &rng)
  91. .set_dtype(2, dtype::Int32())
  92. .set_dtype(3, dtype::Int32());
  93. checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
  94. .execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
  95. .execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});
  96. checker.set_dtype(2, dtype::Uint8()).set_dtype(3, dtype::Uint8());
  97. checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
  98. .execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
  99. .execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});
  100. param.sparse = Convolution::Param::Sparse::GROUP;
  101. checker.set_param(param)
  102. .execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
  103. .execs({{20, 25, 30, 30},
  104. {25, 1, 1, 3, 3},
  105. {20, 30, 30},
  106. {20, 28, 28},
  107. {}});
  108. checker.set_dtype(2, dtype::Int32()).set_dtype(3, dtype::Int32());
  109. checker.execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
  110. .execs({{20, 25, 30, 30},
  111. {25, 1, 1, 3, 3},
  112. {20, 30, 30},
  113. {20, 28, 28},
  114. {}});
  115. }
// Pin down forward semantics with a hand-computed 4x4 example:
// dst(oh, ow) = sum over the 2x2 filter window of src values whose rin
// region id equals rout(oh, ow).
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD_DENSE_BRUTE) {
    Checker<RegionRestrictedConvolutionForward> checker(handle());
    RegionRestrictedConvolutionForward::Param param;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(  // src: NCHW, values 0..15 row-major
                            {1, 1, 4, 4}, dtype::Float32(),
                            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}),
                    TensorValue(  // filter: 2x2 all-ones -> plain window sum
                            {1, 1, 2, 2}, dtype::Float32(), {1, 1, 1, 1}),
                    TensorValue(  // rin: region id (0 or 1) of each input pixel
                            {1, 4, 4}, dtype::Int32(),
                            {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}),
                    TensorValue(  // rout: region id of each output pixel
                            {1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}),
                    {},  // output: computed by the operator
            },
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    // Hand-checked values, e.g. dst(0,0): rout=0, window
                    // {0,1,4,5} has rin {1,1,0,1} -> only 4 survives;
                    // dst(1,2): rout=0 but the window's rin is all 1 -> 0.
                    TensorValue(
                            {1, 1, 3, 3}, dtype::Float32(),
                            {4, 14, 18, 5, 9, 0, 13, 9, 50})});
}
// Hand-computed backward-data case: grad(ih, iw) accumulates diff(oh, ow)
// for every output position whose 2x2 window covers (ih, iw) AND whose rout
// region id matches rin(ih, iw).
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_DENSE_BRUTE) {
    Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
    RegionRestrictedConvolutionBackwardData::Param param;
    checker.set_param(param).exect(
            Testcase{
                    // filter: 2x2 all-ones, so each diff value is scattered
                    // unscaled into its input window
                    TensorValue(
                            {1, 1, 2, 2},      // shape
                            dtype::Float32(),  // dtype
                            {1.f, 1.f, 1.f, 1.f}),
                    // diff: gradients from the 3x3 output, values 1..9
                    TensorValue(
                            {1, 1, 3, 3}, dtype::Float32(),
                            {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}),
                    // rin: region id of each input pixel
                    TensorValue(
                            {1, 4, 4}, dtype::Int32(),
                            {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}),
                    // rout: region id of each output pixel
                    TensorValue({1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}),
                    // grad: computed by the operator
                    {}},
            Testcase{// filter
                     {},
                     // diff
                     {},
                     // rin
                     {},
                     // rout
                     {},
                     // grad, e.g. grad(0,0)=0: the only covering output (0,0)
                     // has rout 0 while rin(0,0) is 1; grad(0,1)=2: covered by
                     // outputs (0,0) [region mismatch] and (0,1) [diff=2].
                     TensorValue(
                             {1, 1, 4, 4}, dtype::Float32(),
                             {0., 2., 5., 3., 1., 6., 5., 3., 0., 13., 9., 9., 0., 7.,
                              9., 9.})});
}
// Grouped backward-data case: 2 groups of 1 channel each. Every region id is
// 1, so the restriction is a no-op and grad reduces to a plain per-group
// deconvolution: grad[g] = diff[g][0][0] * filter[g].
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_GROUP_BRUTE) {
    Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
    // params
    RegionRestrictedConvolutionBackwardData::Param param;
    param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
    param.mode = RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION;
    param.compute_mode =
            RegionRestrictedConvolutionBackwardData::Param::ComputeMode::DEFAULT;
    param.pad_h = param.pad_w =
            0;  // forward param, naive backward data doesn't matter with deconv padding
    param.stride_h = param.stride_w = 1;
    // checker setting
    checker.set_param(param).exect(
            Testcase{// filter: GROUP layout {group, ocpg, icpg, fh, fw};
                     // group0 = {1,2,3,4}, group1 = {5,6,7,8}
                     TensorValue(
                             {2, 1, 1, 2, 2},   // shape
                             dtype::Float32(),  // dtype
                             {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}),
                     // diff: one scalar per group (1x1 output)
                     TensorValue({1, 2, 1, 1}, dtype::Float32(), {1, 2}),
                     // rin: every input pixel in region 1
                     TensorValue({1, 2, 2}, dtype::Int32(), {1, 1, 1, 1}),
                     // rout: the single output pixel also in region 1
                     TensorValue({1, 1, 1}, dtype::Int32(), {1}),
                     // grad: computed by the operator
                     {}},
            Testcase{// filter
                     {},
                     // diff
                     {},
                     // rin
                     {},
                     // rout
                     {},
                     // grad: group0 = 1*{1,2,3,4}, group1 = 2*{5,6,7,8}
                     TensorValue(
                             {1, 2, 2, 2}, dtype::Float32(),
                             {1, 2, 3, 4, 10, 12, 14, 16})});
}
  217. // vim: syntax=cpp.doxygen