You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

conv_bias.cpp 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. /**
  2. * \file dnn/test/naive/conv_bias.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/workspace_wrapper.h"
  15. #include "test/naive/fixture.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. namespace {
  19. class TensorWrapper {
  20. public:
  21. TensorWrapper(Handle* handle, TensorLayout layout) : m_handle(handle) {
  22. m_tensornd.raw_ptr = megdnn_malloc(m_handle, layout.span().dist_byte());
  23. m_tensornd.layout = layout;
  24. }
  25. ~TensorWrapper() { megdnn_free(m_handle, m_tensornd.raw_ptr); }
  26. TensorND tensornd() const { return m_tensornd; }
  27. private:
  28. Handle* m_handle;
  29. TensorND m_tensornd;
  30. };
  31. } // namespace
// Checks the naive ConvBias implementation for NCHW quantized-int8 inputs with
// int32 bias/z/output against a hand-computed golden result.
// Tensor order in the Testcase is: src, filter, bias, z (residual add), dst.
TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32) {
    Checker<ConvBias> checker(handle(), /* check_dispatch */ false);
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    checker.set_param(param).exect(
            // Inputs. The int8 sample values are written as "u - zero" so the
            // raw unsigned samples stay visible; e.g. 90 - 128 == -38.
            Testcase{// src: 1x1x4x4, scale 0.1
                     TensorValue({1, 1, 4, 4}, dtype::QuantizedS8(0.1f),
                                 {90 - 128, 136 - 128, 85 - 128, 204 - 128,
                                  48 - 128, 9 - 128, 226 - 128, 25 - 128,
                                  118 - 128, 109 - 128, 87 - 128, 132 - 128,
                                  104 - 128, 163 - 128, 25 - 128, 90 - 128}),
                     // filter: 3 output channels, 3x3 kernel, scale 0.2
                     TensorValue({3, 1, 3, 3}, dtype::QuantizedS8(0.2f),
                                 {153 - 124, 170 - 124, 102 - 124, 103 - 124,
                                  23 - 124, 213 - 124, 116 - 124, 195 - 124,
                                  191 - 124, 44 - 124, 50 - 124, 247 - 124,
                                  172 - 124, 42 - 124, 32 - 124, 233 - 124,
                                  163 - 124, 247 - 124, 120 - 124, 241 - 124,
                                  209 - 124, 83 - 124, 201 - 124, 115 - 124,
                                  32 - 124, 140 - 124, 147 - 124}),
                     // bias: all-zero, per-channel broadcast shape
                     TensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.02f),
                                 {0, 0, 0}),
                     // z (residual add), mostly zero; its scale (0.3) differs
                     // from dst's, exercising requantization of the z term
                     TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.3f),
                                 {1234, 0, 0, 0, 0, 0, 0, 0, 0, -234, 0, 0}),
                     {}},
            // Expected output: dst scale = src_scale * filter_scale.
            Testcase{{},
                     {},
                     {},
                     {},
                     TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.1f * 0.2f),
                                 {37127, -22475, -15694, -1920,
                                  -12813, 4440, 18190, -13195,
                                  -9659, 12423, -5558, -4969})});
}
// Checks the naive ConvBias implementation for NCHW8 Quantized4Asymm (4-bit
// unsigned, asymmetric) inputs with QuantizedS32 bias/z/output against a
// hand-computed golden result.
TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) {
    Checker<ConvBias> checker(handle(), false);
    using Param = ConvBiasForward::Param;
    Param param;
    param.format = Param::Format::NCHW8;
    checker.set_param(param);
    // Builds a Quantized4Asymm tensor from a list of 4-bit sample values by
    // packing two values per byte (values[i] in the low nibble, values[i+1]
    // in the high nibble).
    // NOTE(review): the loop assumes dist_elem() is even (true for the x8
    // shapes used here) — values[i + 1] would read out of bounds otherwise.
    // NOTE(review): the malloc'ed buffer is intentionally(?) never freed;
    // it must outlive exect(), but this does leak in the test binary.
    auto GenTensorValueQuint4 = [](const TensorShape& shape,
                                   dtype::Quantized4Asymm dtype,
                                   const std::vector<int>& values) {
        TensorND tensor;
        tensor.layout = {shape, dtype};
        tensor.raw_ptr =
                static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
        uint8_t* ptr = static_cast<uint8_t*>(tensor.raw_ptr);
        megdnn_assert(values.size() == tensor.layout.span().dist_elem());
        for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) {
            int val0 = values[i], val1 = values[i + 1];
            ptr[i / 2] = val0 | (val1 << 4);
        }
        return tensor;
    };
    checker.set_param(param).exect(
            Testcase{
                    // src: NCHW8, scale 0.1, zero point 8
                    GenTensorValueQuint4(
                            {1, 1, 4, 4, 8},
                            dtype::Quantized4Asymm(0.1f, uint8_t(8)),
                            {0, 6, 14, 5, 11, 2, 9, 9, 2, 1, 2, 11, 5,
                             0, 4, 8, 12, 15, 7, 7, 11, 0, 4, 1, 14, 9,
                             2, 0, 1, 11, 7, 13, 6, 11, 14, 4, 14, 6, 4,
                             3, 4, 2, 8, 15, 10, 6, 7, 0, 11, 13, 3, 9,
                             5, 13, 0, 5, 4, 5, 10, 5, 5, 0, 3, 13, 5,
                             4, 14, 10, 8, 3, 15, 1, 13, 5, 8, 9, 13, 10,
                             15, 13, 9, 0, 1, 11, 15, 4, 12, 11, 4, 5, 2,
                             9, 10, 9, 3, 1, 15, 10, 0, 1, 4, 6, 11, 2,
                             4, 9, 14, 6, 12, 0, 10, 13, 9, 7, 14, 14, 3,
                             14, 14, 7, 2, 4, 1, 9, 4, 7, 15, 10}),
                    // filter: 8 output channels, 3x3, scale 0.2, zero point 7
                    GenTensorValueQuint4(
                            {8, 1, 3, 3, 8},
                            dtype::Quantized4Asymm(0.2f, uint8_t(7)),
                            {6, 8, 3, 6, 1, 9, 7, 8, 10, 0, 4, 11, 0,
                             1, 9, 8, 3, 3, 0, 9, 3, 2, 2, 2, 10, 5,
                             8, 7, 12, 10, 1, 11, 3, 1, 9, 8, 2, 15, 5,
                             0, 14, 3, 8, 15, 14, 7, 15, 4, 3, 3, 11, 9,
                             8, 4, 7, 14, 4, 6, 10, 7, 5, 5, 2, 0, 5,
                             0, 1, 10, 13, 1, 7, 12, 9, 11, 12, 7, 3, 15,
                             1, 10, 7, 8, 9, 1, 6, 8, 7, 0, 4, 12, 12,
                             11, 4, 0, 14, 1, 6, 15, 15, 4, 1, 2, 10, 9,
                             6, 0, 13, 2, 5, 8, 11, 1, 1, 2, 4, 13, 3,
                             3, 12, 11, 6, 5, 8, 11, 13, 12, 0, 13, 9, 13,
                             12, 1, 7, 10, 6, 12, 8, 13, 11, 1, 3, 5, 0,
                             10, 4, 8, 15, 13, 9, 7, 2, 14, 9, 9, 10, 7,
                             13, 0, 9, 4, 7, 10, 15, 4, 10, 10, 9, 13, 8,
                             7, 10, 9, 13, 12, 14, 8, 3, 6, 4, 8, 5, 5,
                             6, 3, 6, 6, 10, 4, 3, 0, 12, 8, 7, 3, 14,
                             7, 3, 2, 3, 7, 7, 3, 0, 8, 11, 3, 14, 1,
                             13, 10, 5, 7, 9, 15, 8, 9, 1, 3, 11, 13, 13,
                             6, 0, 6, 0, 10, 0, 1, 4, 3, 11, 3, 7, 1,
                             7, 10, 7, 2, 13, 15, 12, 0, 2, 0, 6, 15, 9,
                             13, 2, 10, 2, 1, 13, 13, 7, 7, 2, 10, 1, 12,
                             9, 5, 2, 8, 11, 13, 12, 5, 3, 1, 9, 14, 12,
                             6, 12, 12, 3, 7, 0, 8, 1, 9, 12, 2, 10, 11,
                             5, 11, 10, 10, 13, 9, 3, 1, 4, 9, 6, 2, 15,
                             8, 12, 5, 14, 0, 8, 1, 3, 2, 14, 1, 6, 4,
                             4, 10, 9, 5, 15, 8, 2, 4, 3, 11, 6, 12, 6,
                             3, 14, 5, 11, 5, 9, 15, 8, 3, 5, 3, 11, 9,
                             5, 7, 14, 9, 0, 5, 11, 9, 14, 13, 2, 1, 10,
                             6, 6, 6, 15, 0, 7, 9, 12, 6, 6, 5, 0, 14,
                             15, 9, 10, 10, 13, 7, 12, 5, 13, 2, 7, 14, 7,
                             14, 13, 0, 12, 10, 7, 4, 12, 1, 8, 7, 8, 0,
                             11, 12, 12, 4, 7, 9, 15, 1, 15, 11, 7, 6, 9,
                             0, 10, 6, 7, 5, 11, 14, 13, 14, 6, 3, 0, 3,
                             6, 10, 3, 5, 0, 7, 6, 14, 15, 8, 4, 13, 11,
                             3, 1, 5, 6, 2, 14, 1, 15, 4, 4, 4, 8, 7,
                             13, 0, 8, 14, 10, 8, 14, 7, 11, 0, 2, 15, 13,
                             15, 0, 7, 8, 15, 6, 6, 4, 2, 4, 10, 13, 10,
                             6, 1, 10, 14, 13, 6, 9, 6, 8, 11, 10, 13, 2,
                             6, 10, 0, 1, 6, 15, 7, 6, 6, 13, 9, 2, 9,
                             0, 2, 15, 15, 14, 0, 2, 13, 15, 15, 0, 7, 10,
                             10, 13, 15, 6, 13, 8, 5, 4, 12, 9, 4, 14, 8,
                             6, 13, 15, 2, 8, 10, 11, 6, 11, 0, 15, 0, 1,
                             5, 1, 14, 13, 7, 2, 6, 3, 9, 7, 6, 15, 12,
                             14, 2, 10, 12, 8, 14, 5, 12, 13, 15, 10, 9, 7,
                             7, 13, 6, 11, 13, 9, 4, 8, 9, 2, 11, 13, 8,
                             1, 0, 14, 6}),
                    // bias: all-zero; scale = src_scale * filter_scale
                    TensorValue({1, 1, 1, 1, 8},
                                dtype::QuantizedS32(0.1f * 0.2f),
                                {0, 0, 0, 0, 0, 0, 0, 0}),
                    // z (residual add): single nonzero entry, different scale
                    TensorValue(
                            {1, 1, 2, 2, 8}, dtype::QuantizedS32(0.3f),
                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                             0, 0, 0, 0, 0, 0, -87, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
                    {}},
            // Expected output: dst scale = src_scale * filter_scale.
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 1, 2, 2, 8}, dtype::QuantizedS32(0.1f * 0.2f),
                            {275, -232, 55, -123, 81, -55, -324, 64,
                             -104, -391, 242, -2, -162, -150, -232, -160,
                             -192, -72, -52, -154, 198, -48, -1073, -105,
                             103, -218, -22, 446, -81, 90, -152, -126}),
            });
}
// Cross-format consistency test: computes a reference result with the naive
// ConvBias operator in NCHW4, repacks src/filter/bias/z/dst into NCHW32, and
// verifies the NCHW32 path produces the same output.
TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32_NCHW32) {
    Checker<ConvBias> checker(handle(), /* check_dispatch */ false);
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW32;
    size_t N = 2, IC = 32, IH = 4, IW = 4, OC = 32, PH = 1, PW = 1, SH = 1,
           SW = 1, FH = 3, FW = 3;
    // Reference operator runs in NCHW4 with identical conv hyper-parameters.
    auto&& conv_opr = handle()->create_operator<ConvBias>();
    conv_opr->param().format = ConvBias::Param::Format::NCHW4;
    conv_opr->param().pad_h = param.pad_h = PH;
    conv_opr->param().pad_w = param.pad_w = PW;
    conv_opr->param().stride_h = param.stride_h = SH;
    conv_opr->param().stride_w = param.stride_w = SW;
    size_t OH = infer_conv_shape(IH, FH, SH, PH);
    size_t OW = infer_conv_shape(IW, FW, SW, PW);
    auto i8_min = std::numeric_limits<int8_t>().min();
    auto i8_max = std::numeric_limits<int8_t>().max();
    UniformIntRNG int_rng{i8_min, i8_max};
    // --- NCHW4 tensors, filled with random int8 data ---
    TensorLayout src_layout_4{{N, IC / 4, IH, IW, 4}, dtype::QuantizedS8(0.1f)};
    TensorWrapper src_ts_4{handle(), src_layout_4};
    int_rng.gen(src_ts_4.tensornd());
    TensorLayout filter_layout_4{{OC, IC / 4, FH, FW, 4},
                                 dtype::QuantizedS8(0.2f)};
    TensorWrapper filter_ts_4{handle(), filter_layout_4};
    int_rng.gen(filter_ts_4.tensornd());
    TensorLayout bias_layout_4{{1, OC / 4, 1, 1, 4},
                               dtype::QuantizedS32(0.02f)};
    TensorWrapper bias_ts_4{handle(), bias_layout_4};
    int_rng.gen(bias_ts_4.tensornd());
    TensorLayout dst_layout_4{{N, OC / 4, OH, OW, 4}, dtype::QuantizedS8(0.2f)};
    TensorWrapper dst_ts_4{handle(), dst_layout_4};
    TensorLayout z_layout_4{dst_layout_4, dtype::QuantizedS8(0.3f)};
    TensorWrapper z_ts_4{handle(), z_layout_4};
    int_rng.gen(z_ts_4.tensornd());
    // Produce the reference dst in NCHW4.
    size_t ws_size = conv_opr->get_workspace_in_bytes(
            src_layout_4, filter_layout_4, bias_layout_4, z_layout_4,
            dst_layout_4);
    WorkspaceWrapper ws{handle(), ws_size};
    conv_opr->exec(src_ts_4.tensornd(), filter_ts_4.tensornd(),
                   bias_ts_4.tensornd(), z_ts_4.tensornd(), dst_ts_4.tensornd(),
                   ws.workspace());
    // --- NCHW32 tensors (same logical shapes, channel-blocked by 32) ---
    TensorLayout src_layout_32{{N, IC / 32, IH, IW, 32},
                               dtype::QuantizedS8(0.1f)};
    TensorWrapper src_ts_32{handle(), src_layout_32};
    TensorLayout filter_layout_32{{OC, IC / 32, FH, FW, 32},
                                  dtype::QuantizedS8(0.2f)};
    TensorWrapper filter_ts_32{handle(), filter_layout_32};
    TensorLayout bias_layout_32{{1, OC / 32, 1, 1, 32},
                                dtype::QuantizedS32(0.02f)};
    TensorWrapper bias_ts_32{handle(), bias_layout_32};
    TensorLayout dst_layout_32{{N, OC / 32, OH, OW, 32},
                               dtype::QuantizedS8(0.2f)};
    TensorWrapper dst_ts_32{handle(), dst_layout_32};
    TensorLayout z_layout_32{dst_layout_32, dtype::QuantizedS8(0.3f)};
    TensorWrapper z_ts_32{handle(), z_layout_32};
    // Repacks an NCHW4-blocked tensor into NCHW32 blocking. `out` iterates in
    // NCHW32 order; each logical channel ch maps back to NCHW4 block
    // (ch / 4, ch % 4) in `in`. The int8 and int32 branches are identical
    // except for the element type.
    auto from_nchw4_to_nchw32 = [](const TensorND in, const TensorND out) {
        size_t n = out.layout[0], c = out.layout[1], h = out.layout[2],
               w = out.layout[3];
        if (in.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
            int8_t* in_ptr = in.compatible_ptr<int8_t>();
            int8_t* out_ptr = out.compatible_ptr<int8_t>();
            for (size_t b = 0; b < n; b++) {
                for (size_t ch_out = 0; ch_out < c; ch_out++) {
                    for (size_t h_ = 0; h_ < h; h_++) {
                        for (size_t w_ = 0; w_ < w; w_++) {
                            for (size_t ch_in = 0; ch_in < 32; ch_in++) {
                                size_t ch = ch_out * 32 + ch_in;
                                size_t ch_out_ = ch / 4;
                                size_t ch_in_ = ch % 4;
                                *out_ptr = in_ptr[b * c * h * w * 32 +
                                                  ch_out_ * h * w * 4 +
                                                  h_ * w * 4 + w_ * 4 + ch_in_];
                                out_ptr++;
                            }
                        }
                    }
                }
            }
        }
        if (in.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
            int32_t* in_ptr = in.compatible_ptr<int32_t>();
            int32_t* out_ptr = out.compatible_ptr<int32_t>();
            for (size_t b = 0; b < n; b++) {
                for (size_t ch_out = 0; ch_out < c; ch_out++) {
                    for (size_t h_ = 0; h_ < h; h_++) {
                        for (size_t w_ = 0; w_ < w; w_++) {
                            for (size_t ch_in = 0; ch_in < 32; ch_in++) {
                                size_t ch = ch_out * 32 + ch_in;
                                size_t ch_out_ = ch / 4;
                                size_t ch_in_ = ch % 4;
                                *out_ptr = in_ptr[b * c * h * w * 32 +
                                                  ch_out_ * h * w * 4 +
                                                  h_ * w * 4 + w_ * 4 + ch_in_];
                                out_ptr++;
                            }
                        }
                    }
                }
            }
        }
    };
    from_nchw4_to_nchw32(src_ts_4.tensornd(), src_ts_32.tensornd());
    from_nchw4_to_nchw32(filter_ts_4.tensornd(), filter_ts_32.tensornd());
    from_nchw4_to_nchw32(bias_ts_4.tensornd(), bias_ts_32.tensornd());
    from_nchw4_to_nchw32(dst_ts_4.tensornd(), dst_ts_32.tensornd());
    from_nchw4_to_nchw32(z_ts_4.tensornd(), z_ts_32.tensornd());
    // NCHW32 run must reproduce the repacked NCHW4 reference dst.
    checker.set_param(param).exect(
            TensorNDArray{src_ts_32.tensornd(),
                          filter_ts_32.tensornd(),
                          bias_ts_32.tensornd(),
                          z_ts_32.tensornd(),
                          {}},
            TensorNDArray{{}, {}, {}, {}, dst_ts_32.tensornd()});
}
// Exercises the naive ConvBias implementation in NCHW44 layout:
// random-data runs in float32 and QuantizedS8, then four golden-data cases
// (dense conv, depthwise group conv, grouped conv, and int8/int32 dense conv).
TEST_F(NAIVE, CONV_BIAS_NCHW44) {
    Checker<ConvBias> checker(handle(), /* check_dispatch */ false);
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW44;
    size_t n = 1;
    size_t ic = 4;
    size_t oc = 8;
    size_t h = 2;
    size_t w = 2;
    size_t filter_size = 3;
    size_t pad = 1;
    auto src_tensor_shape = TensorShape{n, ic / 4, h, w, 4};
    // NCHW44 dense filter: {oc/4, ic/4, fh, fw, 4 (ic inner), 4 (oc inner)}
    auto weight_tensor_shape =
            TensorShape{oc / 4, ic / 4, filter_size, filter_size, 4, 4};
    auto bias_tensor_shape = TensorShape{1, oc / 4, 1, 1, 4};
    param.pad_h = pad;
    param.pad_w = pad;
    UniformIntRNG rng{-127, 127};
    // Random-data run in float32 (checked against the reference naive path).
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Float32())
            .set_dtype(4, dtype::Float32())
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng)
            .set_epsilon(1e-3)
            .set_param(param)
            .execs({src_tensor_shape,
                    weight_tensor_shape,
                    bias_tensor_shape,
                    {},
                    {}});
    // Same shapes with quantized int8 inputs and int32 bias/output.
    checker.set_dtype(0, dtype::QuantizedS8(2.f))
            .set_dtype(1, dtype::QuantizedS8(3.f))
            .set_dtype(2, dtype::QuantizedS32(6.f))
            .set_dtype(4, dtype::QuantizedS32(6.f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng)
            .set_epsilon(1e-3)
            .set_param(param)
            .execs({src_tensor_shape,
                    weight_tensor_shape,
                    bias_tensor_shape,
                    {},
                    {}});
    {
        // test normal conv: dense float32 conv with hand-computed output
        ConvBias::Param param;
        param.format = ConvBias::Param::Format::NCHW44;
        param.sparse = ConvBias::Param::Sparse::DENSE;
        param.pad_h = 1;
        param.pad_w = 1;
        checker.set_param(param).exect(
                Testcase{// src
                         TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
                                     {7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7,
                                      6, 4}),
                         // filter
                         TensorValue(
                                 {1, 1, 3, 3, 4, 4}, dtype::Float32(),
                                 {3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0,
                                  7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2,
                                  2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7,
                                  7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4,
                                  1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8,
                                  1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4,
                                  1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3,
                                  2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1,
                                  1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4,
                                  3, 3, 7, 2, 8, 1, 1, 1, 4}),
                         // bias
                         TensorValue({1, 1, 1, 1, 4}, dtype::Float32(),
                                     {7, 2, 8, 1}),
                         // z: all-zero residual
                         TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
                                     {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0}),
                         {}},
                Testcase{
                        {},
                        {},
                        {},
                        {},
                        // expected dst
                        TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
                                    {264, 338, 309, 195, 276, 332, 390, 199,
                                     224, 268, 311, 218, 288, 311, 346, 277})});
    }
    {
        // test dw conv: depthwise (GROUP sparse, 1 channel per group)
        ConvBias::Param param;
        param.format = ConvBias::Param::Format::NCHW44;
        param.sparse = ConvBias::Param::Sparse::GROUP;
        param.pad_h = 1;
        param.pad_w = 1;
        checker.set_param(param).exect(
                Testcase{// src
                         TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
                                     {5, 8, 3, 2, 4, 6, 1, 5, 0, 8, 2, 6, 8, 6,
                                      5, 7}),
                         // filter: {group, 1, 1, fh, fw, 4}
                         TensorValue({1, 1, 1, 3, 3, 4}, dtype::Float32(),
                                     {3, 0, 3, 1, 6, 5, 7, 3, 5, 0, 0, 7,
                                      4, 6, 0, 1, 8, 2, 3, 7, 1, 0, 2, 4,
                                      7, 5, 3, 0, 6, 2, 1, 5, 8, 6, 3, 1}),
                         // bias
                         TensorValue({1, 1, 1, 1, 4}, dtype::Float32(),
                                     {4, 3, 5, 6}),
                         // z: all-zero residual
                         TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
                                     {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0}),
                         {}},
                Testcase{{},
                         {},
                         {},
                         {},
                         // expected dst
                         TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
                                     {112, 71, 33, 77, 104, 115, 19, 78, 62, 59,
                                      42, 117, 107, 93, 36, 78})});
    }
    {
        // test group conv: 2 groups, 4 channels each
        ConvBias::Param param;
        param.format = ConvBias::Param::Format::NCHW44;
        param.sparse = ConvBias::Param::Sparse::GROUP;
        param.pad_h = 1;
        param.pad_w = 1;
        checker.set_param(param).exect(
                Testcase{// src: 2 channel blocks
                         TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
                                     {6, 3, 2, 7, 7, 6, 4, 5, 8, 6, 3,
                                      1, 1, 2, 8, 3, 1, 0, 6, 1, 3, 3,
                                      6, 0, 0, 5, 6, 7, 2, 2, 4, 4}),
                         // filter: {group=2, 1, 1, fh, fw, 4, 4}
                         TensorValue(
                                 {2, 1, 1, 3, 3, 4, 4}, dtype::Float32(),
                                 {3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0,
                                  7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2,
                                  2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7,
                                  7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4,
                                  1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8,
                                  1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4,
                                  1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3,
                                  2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1,
                                  1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4,
                                  3, 3, 7, 2, 8, 1, 1, 1, 4, 7, 4, 5, 0, 6, 8,
                                  7, 4, 8, 1, 3, 5, 3, 0, 0, 3, 7, 7, 7, 3, 8,
                                  1, 2, 0, 1, 1, 2, 1, 3, 0, 0, 1, 1, 3, 0, 5,
                                  6, 3, 0, 5, 4, 1, 4, 7, 0, 2, 1, 6, 7, 8, 0,
                                  2, 1, 6, 7, 6, 3, 2, 7, 6, 5, 1, 1, 1, 2, 4,
                                  6, 3, 3, 8, 0, 7, 1, 3, 7, 3, 2, 2, 4, 3, 5,
                                  5, 6, 3, 3, 1, 2, 3, 0, 4, 0, 3, 3, 5, 5, 5,
                                  2, 3, 1, 5, 4, 5, 8, 1, 7, 2, 1, 0, 1, 8, 2,
                                  6, 7, 8, 4, 4, 7, 8, 4, 5, 8, 1, 1, 0, 7, 8,
                                  4, 2, 2, 8, 6, 5, 2, 4, 8, 4, 0, 4, 0, 2, 1,
                                  7, 1, 6}),
                         // bias: one 4-vector per group
                         TensorValue({1, 2, 1, 1, 4}, dtype::Float32(),
                                     {1, 8, 5, 6, 2, 8, 7, 7}),
                         // z: all-zero residual
                         TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
                                     {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
                         {}},
                Testcase{
                        {},
                        {},
                        {},
                        {},
                        // expected dst
                        TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
                                    {260, 342, 244, 241, 293, 385, 362, 257,
                                     278, 301, 303, 226, 273, 306, 318, 307,
                                     180, 244, 169, 156, 210, 244, 206, 167,
                                     126, 165, 156, 207, 191, 141, 209, 172})});
    }
    {
        // test normal conv: same data as the first dense case, but with
        // Int8 src/filter and Int32 bias/output (non-quantized integer path)
        ConvBias::Param param;
        param.format = ConvBias::Param::Format::NCHW44;
        param.sparse = ConvBias::Param::Sparse::DENSE;
        param.pad_h = 1;
        param.pad_w = 1;
        checker.set_param(param).exect(
                Testcase{// src
                         TensorValue({1, 1, 2, 2, 4}, dtype::Int8(),
                                     {7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7,
                                      6, 4}),
                         // filter
                         TensorValue(
                                 {1, 1, 3, 3, 4, 4}, dtype::Int8(),
                                 {3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0,
                                  7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2,
                                  2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7,
                                  7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4,
                                  1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8,
                                  1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4,
                                  1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3,
                                  2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1,
                                  1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4,
                                  3, 3, 7, 2, 8, 1, 1, 1, 4}),
                         // bias
                         TensorValue({1, 1, 1, 1, 4}, dtype::Int32(),
                                     {7, 2, 8, 1}),
                         // z: all-zero residual
                         TensorValue({1, 1, 2, 2, 4}, dtype::Int32(),
                                     {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                      0, 0}),
                         {}},
                Testcase{
                        {},
                        {},
                        {},
                        {},
                        // expected dst: matches the float32 dense case above
                        TensorValue({1, 1, 2, 2, 4}, dtype::Int32(),
                                    {264, 338, 309, 195, 276, 332, 390, 199,
                                     224, 268, 311, 218, 288, 311, 346, 277})});
    }
}
  486. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)