You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

convolution.cpp 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. #include "test/common/convolution.h"
  2. #include "src/common/algo_base.h"
  3. #include "test/common/checker.h"
  4. #include <sstream>
  5. #include <unordered_set>
  6. using namespace megdnn;
  7. using namespace test;
  8. using namespace convolution;
  9. std::vector<TestArg> convolution::get_1x1_args() {
  10. std::vector<TestArg> args;
  11. param::Convolution param;
  12. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  13. // clang-format off
  14. for (size_t batch_size: {1, 8})
  15. for (size_t ic: {1, 16})
  16. for (size_t oc: {1, 16})
  17. for (size_t ih : {8, 32}) {
  18. size_t iw = ih;
  19. args.emplace_back(param, TensorShape{batch_size, ic, ih, iw},
  20. TensorShape{oc, ic, 1, 1});
  21. }
  22. // clang-format on
  23. return args;
  24. }
  25. std::vector<TestArg> convolution::get_args_common() {
  26. std::vector<TestArg> args;
  27. for (size_t i = 16; i < 24; ++i) {
  28. param::Convolution param;
  29. param.mode = param::Convolution::Mode::CONVOLUTION;
  30. args.emplace_back(param, TensorShape{5, 2, i, i + 1}, TensorShape{3, 2, 3, 4});
  31. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  32. args.emplace_back(param, TensorShape{5, 2, i, i + 1}, TensorShape{3, 2, 3, 4});
  33. }
  34. return args;
  35. }
  36. std::vector<TestArg> convolution::get_args_padding() {
  37. std::vector<TestArg> args;
  38. for (size_t i = 16; i < 24; ++i) {
  39. param::Convolution param;
  40. param.pad_h = 1;
  41. param.pad_w = 2;
  42. param.mode = param::Convolution::Mode::CONVOLUTION;
  43. args.emplace_back(param, TensorShape{5, 2, i, i + 1}, TensorShape{3, 2, 3, 4});
  44. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  45. args.emplace_back(param, TensorShape{5, 2, i, i + 1}, TensorShape{3, 2, 3, 4});
  46. }
  47. return args;
  48. }
  49. std::vector<TestArg> convolution::get_args_large_channel() {
  50. std::vector<TestArg> args;
  51. for (size_t i = 16; i < 24; ++i) {
  52. param::Convolution param;
  53. param.mode = param::Convolution::Mode::CONVOLUTION;
  54. args.emplace_back(
  55. param, TensorShape{2, 20, i, i + 1}, TensorShape{30, 20, 3, 4});
  56. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  57. args.emplace_back(
  58. param, TensorShape{2, 20, i, i + 1}, TensorShape{30, 20, 3, 4});
  59. }
  60. for (size_t i = 16; i < 24; ++i) {
  61. param::Convolution param;
  62. param.pad_h = 1;
  63. param.pad_w = 2;
  64. param.mode = param::Convolution::Mode::CONVOLUTION;
  65. args.emplace_back(
  66. param, TensorShape{2, 20, i, i + 1}, TensorShape{30, 20, 3, 4});
  67. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  68. args.emplace_back(
  69. param, TensorShape{2, 20, i, i + 1}, TensorShape{30, 20, 3, 4});
  70. }
  71. return args;
  72. }
  73. std::vector<TestArg> convolution::get_args_1x1() {
  74. std::vector<TestArg> args;
  75. for (size_t i = 16; i < 24; ++i) {
  76. param::Convolution param;
  77. param.mode = param::Convolution::Mode::CONVOLUTION;
  78. args.emplace_back(
  79. param, TensorShape{2, 20, i, i + 1}, TensorShape{30, 20, 1, 1});
  80. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  81. args.emplace_back(
  82. param, TensorShape{2, 20, i, i + 1}, TensorShape{30, 20, 1, 1});
  83. }
  84. return args;
  85. }
  86. std::vector<TestArg> convolution::get_args_large_filter() {
  87. std::vector<TestArg> args;
  88. for (size_t i = 16; i < 24; ++i) {
  89. param::Convolution param;
  90. param.mode = param::Convolution::Mode::CONVOLUTION;
  91. args.emplace_back(param, TensorShape{2, 2, i, i + 1}, TensorShape{3, 2, 7, 8});
  92. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  93. args.emplace_back(param, TensorShape{2, 2, i, i + 1}, TensorShape{3, 2, 7, 8});
  94. }
  95. return args;
  96. }
  97. std::vector<TestArg> convolution::get_args_exhaustive_search() {
  98. std::vector<TestArg> args;
  99. // clang-format off
  100. for (size_t n: {1, 2})
  101. for (size_t ih: {11, 13})
  102. for (size_t iw: {ih+1})
  103. for (size_t ic: {3})
  104. for (size_t oc: {4})
  105. for (size_t fh: {3, 6})
  106. for (size_t fw: {fh+1})
  107. for (size_t ph: {0, 1})
  108. for (size_t sh: {1, 2})
  109. for (bool xcorr : {false, true}) {
  110. param::Convolution param;
  111. param.mode = xcorr ? param::Convolution::Mode::CROSS_CORRELATION
  112. : param::Convolution::Mode::CONVOLUTION;
  113. param.stride_h = param.stride_w = sh;
  114. param.pad_h = param.pad_w = ph;
  115. args.emplace_back(param, TensorShape{n, ic, ih, iw},
  116. TensorShape{oc, ic, fh, fw});
  117. }
  118. // clang-format on
  119. return args;
  120. }
  121. std::vector<TestArg> convolution::get_args_4x4() {
  122. std::vector<TestArg> args;
  123. for (size_t oh = 1; oh < 20; ++oh) {
  124. param::Convolution param;
  125. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  126. args.emplace_back(
  127. param, TensorShape{4, 3, oh + 3, oh + 4}, TensorShape{2, 3, 4, 4});
  128. }
  129. return args;
  130. }
  131. std::vector<TestArg> convolution::get_args_large_channels() {
  132. std::vector<TestArg> args;
  133. // clang-format off
  134. for (size_t n: {2})
  135. for (size_t ih: {13})
  136. for (size_t iw: {ih+1})
  137. for (size_t ic: {32})
  138. for (size_t oc: {32})
  139. for (size_t fh: {3, 6})
  140. for (size_t fw: {fh+1})
  141. for (size_t ph: {0, 1})
  142. for (size_t sh: {1, 2})
  143. for (bool xcorr : {false, true}) {
  144. param::Convolution param;
  145. param.mode = xcorr ? param::Convolution::Mode::CROSS_CORRELATION
  146. : param::Convolution::Mode::CONVOLUTION;
  147. param.stride_h = param.stride_w = sh;
  148. param.pad_h = param.pad_w = ph;
  149. args.emplace_back(param, TensorShape{n, ic, ih, iw},
  150. TensorShape{oc, ic, fh, fw});
  151. }
  152. // clang-format on
  153. return args;
  154. }
  155. std::vector<TestArg> convolution::get_args_x86_direct_case_2() {
  156. std::vector<TestArg> args;
  157. // clang-format off
  158. for (size_t stride: {1, 2})
  159. for (size_t ker_size : {3, 5, 7, 9}) {
  160. param::Convolution param;
  161. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  162. param.stride_h = param.stride_w = stride;
  163. param.pad_h = param.pad_w = ker_size / 2;
  164. args.emplace_back(param, TensorShape{2, 2, 100, 99},
  165. TensorShape{3, 2, ker_size, ker_size});
  166. args.emplace_back(param, TensorShape{2, 2, 100, 99},
  167. TensorShape{1, 2, ker_size, ker_size});
  168. }
  169. // clang-format on
  170. return args;
  171. }
  172. std::vector<TestArg> convolution::get_args_fallback_templated_impl() {
  173. std::vector<TestArg> args;
  174. // clang-format off
  175. for (size_t sh: {1, 2})
  176. for (size_t sw: {1, 2})
  177. for (size_t ph: {0, 1, 2})
  178. for (size_t pw: {0, 1, 2})
  179. for (size_t ker_size: {3, 4, 5, 7})
  180. for (bool xcorr : {false, true}) {
  181. param::Convolution param;
  182. param.mode = xcorr ? param::Convolution::Mode::CROSS_CORRELATION
  183. : param::Convolution::Mode::CONVOLUTION;
  184. param.stride_h = sh;
  185. param.stride_w = sw;
  186. param.pad_h = ph;
  187. param.pad_w = pw;
  188. args.emplace_back(param, TensorShape{2, 2, 50, 55},
  189. TensorShape{3, 2, ker_size, ker_size});
  190. args.emplace_back(param, TensorShape{2, 2, 50, 55},
  191. TensorShape{1, 2, ker_size, ker_size});
  192. }
  193. // clang-format on
  194. return args;
  195. }
  196. std::vector<TestArg> convolution::get_args_fallback_non_templated_impl() {
  197. std::vector<TestArg> args;
  198. // clang-format off
  199. for (size_t sh: {1, 2})
  200. for (size_t sw: {1, 2})
  201. for (size_t ph: {0, 1, 2})
  202. for (size_t pw: {0, 1, 2})
  203. for (size_t ker_size: {3, 4, 5, 7})
  204. for (bool xcorr : {false, true}) {
  205. param::Convolution param;
  206. param.mode = xcorr ? param::Convolution::Mode::CROSS_CORRELATION
  207. : param::Convolution::Mode::CONVOLUTION;
  208. param.stride_h = sh;
  209. param.stride_w = sw;
  210. param.pad_h = ph;
  211. param.pad_w = pw;
  212. args.emplace_back(param, TensorShape{2, 2, 10, 55},
  213. TensorShape{3, 2, ker_size, ker_size + 1});
  214. args.emplace_back(param, TensorShape{2, 2, 10, 55},
  215. TensorShape{1, 2, ker_size, ker_size + 1});
  216. }
  217. // clang-format on
  218. return args;
  219. }
  220. std::vector<TestArg> convolution::get_args_cudnn_5_1_failures() {
  221. std::vector<TestArg> args;
  222. args.emplace_back(
  223. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0, 4, 1, 2},
  224. TensorShape{5, 3, 25, 20}, TensorShape{10, 3, 7, 4});
  225. return args;
  226. }
  227. std::vector<TestArg> convolution::get_args_cudnn_5_1_backward() {
  228. std::vector<TestArg> args;
  229. args.emplace_back(
  230. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 2, 2, 2, 2},
  231. TensorShape{2, 8, 18, 18}, TensorShape{8, 8, 2, 2});
  232. return args;
  233. }
  234. std::vector<TestArg> convolution::get_args_x86_winograd_algorithm() {
  235. std::vector<TestArg> args;
  236. for (size_t ic_size : {8, 16}) {
  237. param::Convolution param;
  238. param.mode = param::Convolution::Mode::CROSS_CORRELATION;
  239. param.stride_h = param.stride_w = 1;
  240. param.pad_h = param.pad_w = 0;
  241. args.emplace_back(
  242. param, TensorShape{2, ic_size, 102, 102},
  243. TensorShape{8, ic_size, 3, 3});
  244. }
  245. return args;
  246. }
  247. std::vector<TestArg> convolution::get_args_BRAIN_481() {
  248. std::vector<TestArg> args;
  249. {
  250. param::Convolution param{
  251. param::Convolution::Mode::CROSS_CORRELATION, 0, 1, 1, 2};
  252. args.emplace_back(param, TensorShape{4, 4, 14, 13}, TensorShape{3, 4, 8, 13});
  253. for (size_t margin = 0; margin < 5; ++margin) {
  254. param::Convolution param{
  255. param::Convolution::Mode::CROSS_CORRELATION, 1, 1, 2, 2};
  256. args.emplace_back(
  257. param, TensorShape{4, 4, 14, 13},
  258. TensorShape{3, 4, 16 - margin, 15 - margin});
  259. }
  260. }
  261. return args;
  262. }
  263. std::vector<TestArg> convolution::get_args() {
  264. std::vector<TestArg> all_args, args;
  265. #define ADD_ARGS(NAME) \
  266. args = get_args_##NAME(); \
  267. all_args.insert(all_args.end(), args.begin(), args.end());
  268. ADD_ARGS(common)
  269. ADD_ARGS(padding)
  270. ADD_ARGS(large_channel)
  271. ADD_ARGS(1x1)
  272. ADD_ARGS(large_filter)
  273. ADD_ARGS(exhaustive_search)
  274. ADD_ARGS(4x4)
  275. ADD_ARGS(large_channels)
  276. ADD_ARGS(x86_direct_case_2)
  277. ADD_ARGS(fallback_templated_impl)
  278. ADD_ARGS(fallback_non_templated_impl)
  279. ADD_ARGS(cudnn_5_1_failures)
  280. ADD_ARGS(x86_winograd_algorithm)
  281. ADD_ARGS(BRAIN_481)
  282. #undef ADD_ARGS
  283. return all_args;
  284. }
  285. std::vector<TestArg> convolution::get_args_cuda_conv_bwd_data() {
  286. std::vector<TestArg> all_args, args;
  287. #define ADD_ARGS(NAME) \
  288. args = get_args_##NAME(); \
  289. all_args.insert(all_args.end(), args.begin(), args.end());
  290. ADD_ARGS(common)
  291. ADD_ARGS(padding)
  292. ADD_ARGS(large_channel)
  293. ADD_ARGS(1x1)
  294. ADD_ARGS(large_filter)
  295. ADD_ARGS(exhaustive_search)
  296. ADD_ARGS(4x4)
  297. ADD_ARGS(large_channels)
  298. ADD_ARGS(x86_direct_case_2)
  299. ADD_ARGS(fallback_templated_impl)
  300. ADD_ARGS(fallback_non_templated_impl)
  301. ADD_ARGS(x86_winograd_algorithm)
  302. #undef ADD_ARGS
  303. return all_args;
  304. }
  305. std::vector<TestArg> convolution::get_args_cudnn_7_5_failures() {
  306. std::vector<TestArg> all_args, args;
  307. #define ADD_ARGS(NAME) \
  308. args = get_args_##NAME(); \
  309. all_args.insert(all_args.end(), args.begin(), args.end());
  310. ADD_ARGS(cudnn_5_1_failures)
  311. ADD_ARGS(BRAIN_481)
  312. #undef ADD_ARGS
  313. return all_args;
  314. }
  315. std::vector<TestArg> convolution::get_chanwise_args() {
  316. std::vector<TestArg> args;
  317. // clang-format off
  318. for (size_t n: {2})
  319. for (size_t ih: {13})
  320. for (size_t iw: {ih+1})
  321. for (size_t c: {4, 36, 128, 320})
  322. for (size_t fh: {3, 5})
  323. for (size_t fw: {fh+1})
  324. for (size_t ph: {0, 1})
  325. for (size_t sh: {1, 2})
  326. for (size_t dh : {1, 2}) {
  327. param::Convolution param;
  328. param.sparse = param::Convolution::Sparse::GROUP;
  329. param.stride_h = param.stride_w = sh;
  330. param.pad_h = param.pad_w = ph;
  331. param.dilate_h = param.dilate_w = dh;
  332. args.emplace_back(param, TensorShape{n, c, ih, iw},
  333. TensorShape{c, 1, 1, fh, fw});
  334. }
  335. // clang-format on
  336. return args;
  337. }
  338. std::vector<TestArg> convolution::get_dilated_args() {
  339. std::vector<TestArg> args;
  340. param::Convolution param;
  341. param.pad_h = param.pad_w = 2;
  342. param.dilate_h = param.dilate_w = 2;
  343. size_t n = 1, ic = 15, ih = 128, iw = 128, fh = 3, fw = 3, oc = 17;
  344. args.emplace_back(param, TensorShape{n, ic, ih, iw}, TensorShape{oc, ic, fh, fw});
  345. // exhaustive search
  346. // clang-format off
  347. for (size_t n: {2})
  348. for (size_t ih: {23})
  349. for (size_t iw: {ih+1})
  350. for (size_t ic: {3})
  351. for (size_t oc: {4})
  352. for (size_t fh: {3, 6})
  353. for (size_t fw: {fh+1})
  354. for (size_t ph: {0, 1})
  355. for (size_t sh: {2})
  356. for (size_t dh : {3}) {
  357. param::Convolution param;
  358. param.stride_h = param.stride_w = sh;
  359. param.pad_h = param.pad_w = ph;
  360. param.dilate_h = dh;
  361. param.dilate_w = 3;
  362. args.emplace_back(param, TensorShape{n, ic, ih, iw},
  363. TensorShape{oc, ic, fh, fw});
  364. }
  365. // clang-format on
  366. return args;
  367. }
  368. std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() {
  369. std::vector<TestArg> args;
  370. param::Convolution cur_param;
  371. // clang-format off
  372. for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) {
  373. for (size_t b : {64, 16}) {
  374. for (size_t ic : {16, 32}) {
  375. for (size_t oc : {16, 32}) {
  376. for (size_t h : {8}) {
  377. for (size_t w : {8, 11}) {
  378. for (size_t kernel_size : {3, 4, 5, 7}) {
  379. for (int p : {0, static_cast<int>(kernel_size / 2)}) {
  380. for (size_t s : {2}) {
  381. if (kernel_size >= 7) {
  382. b = std::min(b, 32_z);
  383. }
  384. size_t f = kernel_size;
  385. cur_param.mode = mode;
  386. cur_param.format = param::Convolution::Format::NCHW4;
  387. cur_param.sparse = param::Convolution::Sparse::DENSE;
  388. cur_param.pad_h = cur_param.pad_w = p;
  389. cur_param.stride_h = cur_param.stride_w = s;
  390. //! bias channel
  391. args.emplace_back(cur_param, TensorShape{b, ic / 4, h, w, 4},
  392. TensorShape{oc, ic / 4, f, f, 4});
  393. } } } } } } } } }
  394. // clang-format on
  395. cur_param.pad_h = cur_param.pad_w = 1;
  396. cur_param.stride_h = cur_param.stride_w = 1;
  397. args.emplace_back(
  398. cur_param, TensorShape{16, 4, 8, 11, 4}, TensorShape{16, 4, 3, 3, 4});
  399. return args;
  400. }
  401. std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() {
  402. std::vector<TestArg> args;
  403. param::Convolution cur_param;
  404. // clang-format off
  405. for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) {
  406. for (size_t b : {64, 16}) {
  407. for (size_t ic : {16, 32}) {
  408. for (size_t oc : {16, 32}) {
  409. for (size_t h : {8}) {
  410. for (size_t w : {8, 11}) {
  411. for (size_t kernel_size : {3, 4, 5, 7}) {
  412. for (int p : {0, static_cast<int>(kernel_size / 2)}) {
  413. for (size_t s : {2}) {
  414. if (kernel_size >= 7) {
  415. b = std::min(b, 32_z);
  416. }
  417. size_t f = kernel_size;
  418. cur_param.mode = mode;
  419. cur_param.format = param::Convolution::Format::NCHW;
  420. cur_param.sparse = param::Convolution::Sparse::DENSE;
  421. cur_param.pad_h = cur_param.pad_w = p;
  422. cur_param.stride_h = cur_param.stride_w = s;
  423. //! bias channel
  424. args.emplace_back(cur_param, TensorShape{b, ic, h, w},
  425. TensorShape{oc, ic, f, f});
  426. } } } } } } } } }
  427. // clang-format on
  428. // test stride = 1
  429. cur_param.pad_h = cur_param.pad_w = 1;
  430. cur_param.stride_h = cur_param.stride_w = 1;
  431. args.emplace_back(cur_param, TensorShape{16, 16, 8, 11}, TensorShape{16, 16, 3, 3});
  432. return args;
  433. }
  434. std::vector<TestArg> convolution::get_args_int8_nhwc_conv_bwd_data() {
  435. std::vector<TestArg> args;
  436. param::Convolution cur_param;
  437. // clang-format off
  438. for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) {
  439. for (size_t b : {64, 16}) {
  440. for (size_t ic : {16, 32}) {
  441. for (size_t oc : {16, 32}) {
  442. for (size_t h : {8}) {
  443. for (size_t w : {8, 11}) {
  444. for (size_t kernel_size : {3, 4, 5, 7}) {
  445. for (int p : {0, static_cast<int>(kernel_size / 2)}) {
  446. for (size_t s : {2}) {
  447. if (kernel_size >= 7) {
  448. b = std::min(b, 32_z);
  449. }
  450. size_t f = kernel_size;
  451. cur_param.mode = mode;
  452. cur_param.format = param::Convolution::Format::NHWC;
  453. cur_param.sparse = param::Convolution::Sparse::DENSE;
  454. cur_param.pad_h = cur_param.pad_w = p;
  455. cur_param.stride_h = cur_param.stride_w = s;
  456. //! bias channel
  457. args.emplace_back(cur_param, TensorShape{b, h, w, ic},
  458. TensorShape{oc, f, f, ic});
  459. } } } } } } } } }
  460. // clang-format on
  461. cur_param.pad_h = cur_param.pad_w = 1;
  462. cur_param.stride_h = cur_param.stride_w = 1;
  463. args.emplace_back(cur_param, TensorShape{16, 8, 11, 16}, TensorShape{16, 3, 3, 16});
  464. return args;
  465. }
//! Exhaustively tests forward (and optionally backward-data / backward-filter)
//! convolution over every combination of mode, padding, stride, grouping,
//! non-square shapes, dilation, layout format, dtype and kernel size, running
//! every available algorithm for each configuration.
//!
//! \param k_size        larger of the two kernel sizes tested ({1, k_size})
//! \param handle        megdnn handle the checkers run on
//! \param test_int8     also test i8x8x16 / i8x8x32 dtype configs
//! \param test_backward also run backward-data and backward-filter checkers
//! \param is_cuda       skip NHWC configs unless dtype is INT8x8x32
//! \param eps_getter    callback (is_f16, stage 0/1/2, algo name) -> epsilon
//! \param use_io16xc32  for f16 configs, compute in f32 (ComputeMode::FLOAT32)
void convolution::test_conv_config_combinations(
        int k_size, Handle* handle, bool test_int8, bool test_backward, bool is_cuda,
        ConvEPSGetter eps_getter, bool use_io16xc32) {
    Checker<Convolution> checker(handle);
    // Backward checkers are only constructed when test_backward is set.
    std::unique_ptr<Checker<ConvolutionBackwardData>> checker_bwd_data_ptr;
    std::unique_ptr<Checker<ConvolutionBackwardFilter>> checker_bwd_filter_ptr;
    if (test_backward) {
        checker_bwd_data_ptr.reset(
                new std::remove_reference<decltype(*checker_bwd_data_ptr)>::type(
                        handle));
        checker_bwd_filter_ptr.reset(
                new std::remove_reference<decltype(*checker_bwd_filter_ptr)>::type(
                        handle));
    }
    // NOTE(review): when test_backward is false these dereference null
    // unique_ptrs; the references are never used in that case, but binding
    // them is still formally UB — consider guarding. TODO confirm upstream.
    auto&& checker_bwd_data = *checker_bwd_data_ptr;
    auto&& checker_bwd_filter = *checker_bwd_filter_ptr;
// Each CONF_BOOL introduces a 0/1 loop; the seven uses below nest, so the
// body runs once per combination of the seven boolean flags.
#define CONF_BOOL(var) for (int var : {0, 1})
    // Track which algorithms were exercised across all configurations.
    std::unordered_set<Convolution::AlgorithmDesc> used_algos;
    std::unordered_set<ConvolutionBackwardData::AlgorithmDesc> used_algos_bwd_data;
    std::unordered_set<ConvolutionBackwardFilter::AlgorithmDesc> used_algos_bwd_flt;
    using Param = Convolution::Param;
    CONF_BOOL(conv)
    CONF_BOOL(padding)
    CONF_BOOL(stride)
    CONF_BOOL(group)
    CONF_BOOL(non_square)
    CONF_BOOL(dilation)
    CONF_BOOL(format)
    // dtype: 0: f32; 1: f16; 2: i8x8x16 3: i8x8x32
    for (int dtype = 0; dtype < (test_int8 ? 4 : 2); ++dtype)
    for (int ksize : {1, k_size}) {
        // When is_cuda is on, test cases where format is NHWC and
        // data type is not INT8x8x32 are disabled.
        if (is_cuda) {
            if (format && dtype != 3)
                continue;
        }
        // Compact digit string identifying the current configuration,
        // used only in failure messages.
        auto config2str = [&]() -> std::string {
            std::ostringstream ostr;
            ostr << conv << padding << stride << group << non_square << dilation
                 << format << dtype << ksize;
            return ostr.str();
        };
        auto errmsg = [&](const char* name) {
            std::string ret;
            ret += "checker failed for algorithm ";
            ret += name;
            ret += " with conv,padding,stride,group,non_square,dilation,format,"
                   "dtype,ksize=";
            ret += config2str();
            return ret;
        };
        MEGDNN_MARK_USED_VAR(errmsg);
        Param param;
        param.mode =
                conv ? Param::Mode::CONVOLUTION : Param::Mode::CROSS_CORRELATION;
        param.format = format ? Param::Format::NHWC : Param::Format::NCHW;
        if (dtype == 1 && use_io16xc32) {
            param.compute_mode = Param::ComputeMode::FLOAT32;
        }
        size_t IC = 6, OC = 9, G = 3, FH = ksize, FW = ksize;
        // Input shape: batch 2, 18x18 spatial, laid out per `format`.
        TensorShape ishp = TensorShape{2, 18, 18, IC}, fshp;
        if (format) {
            // NHWC
            ishp.shape[0] = 2;
            ishp.shape[1] = 18;
            ishp.shape[2] = 18;
            ishp.shape[3] = IC;
        } else {
            // NCHW
            ishp.shape[0] = 2;
            ishp.shape[1] = IC;
            ishp.shape[2] = 18;
            ishp.shape[3] = 18;
        }
        if (padding) {
            param.pad_h = 2 + non_square;
            param.pad_w = 2 - non_square;
        }
        if (non_square) {
            // Make filter and input rectangular.
            if (FH > 2)
                FH -= 2;
            FW += 1;
            ++ishp[format ? 2 : 3];
        }
        if (group) {
            fshp = format ? TensorShape{G, OC / G, FH, FW, IC / G}
                          : TensorShape{G, OC / G, IC / G, FH, FW};
            param.sparse = Param::Sparse::GROUP;
        } else {
            fshp = format ? TensorShape{OC, FH, FW, IC}
                          : TensorShape{OC, IC, FH, FW};
        }
        if (dilation) {
            param.dilate_h = 2 - non_square;
            param.dilate_w = 2 + non_square;
        }
        if (stride) {
            param.stride_h = 2 + non_square;
            param.stride_w = 2 - non_square;
        }
        // Map dtype index to input/output data types.
        DType inp_type, out_type;
        if (dtype == 2) {
            inp_type = dtype::Int8();
            out_type = dtype::Int16();
        } else if (dtype == 3) {
            inp_type = dtype::Int8();
            out_type = dtype::Int32();
        } else {
            if (!dtype)
                inp_type = dtype::Float32();
            else
                inp_type = dtype::Float16();
            out_type = inp_type;
        }
        checker.set_dtype(0, inp_type)
                .set_dtype(1, inp_type)
                .set_dtype(2, out_type)
                .set_param(param);
        auto opr = checker.opr();
        opr->param() = param;
        std::string param_str;
        Algorithm::serialize_write_pod(opr->param(), param_str);
        TensorLayout ily{ishp, inp_type}, fly{fshp, inp_type}, oly;
        oly.dtype = out_type;
        opr->deduce_layout(ily, fly, oly);
        int channel_start = 1;
        if (format)
            channel_start = 3;
        // Scale the random inputs down so accumulated sums stay well-behaved.
        float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_rng(0, &rng).set_rng(1, &rng);
        // Forward: run every available algorithm for this configuration.
        for (auto algo : opr->get_all_algorithms_info_safe(ily, fly, oly)) {
            used_algos.insert(algo.desc);
            opr->execution_policy().algo = algo.desc;
            construct_sub_execution_policy_heuristic<ConvolutionForward>(
                    opr->execution_policy(), {ily, fly, oly}, param_str,
                    opr->handle());
            checker.set_epsilon(eps_getter(dtype == 1, 0, algo.desc.name.c_str()))
                    .execs({ishp, fshp, {}});
            opr->execution_policy() = {};
            ASSERT_TRUE(checker.prev_succ()) << errmsg(algo.desc.name.c_str());
        }
        if (test_backward) {
            // backward data
            checker_bwd_data.set_dtype(0, inp_type)
                    .set_dtype(1, out_type)
                    .set_dtype(2, inp_type)
                    .set_param(param);
            auto opr = checker_bwd_data.opr();
            opr->param() = param;
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            for (auto algo : opr->get_all_algorithms_info_safe(fly, oly, ily)) {
                used_algos_bwd_data.insert(algo.desc);
                opr->execution_policy().algo = algo.desc;
                construct_sub_execution_policy_heuristic<ConvolutionBackwardData>(
                        opr->execution_policy(), {fly, oly, ily}, param_str,
                        opr->handle());
                checker_bwd_data
                        .set_epsilon(
                                eps_getter(dtype == 1, 1, algo.desc.name.c_str()))
                        .execl({fly, oly, ily});
                opr->execution_policy() = {};
                ASSERT_TRUE(checker_bwd_data.prev_succ())
                        << errmsg(algo.desc.name.c_str());
            }
        }
        if (test_backward) {
            // backward filter
            checker_bwd_filter.set_dtype(0, inp_type)
                    .set_dtype(1, out_type)
                    .set_dtype(2, inp_type)
                    .set_param(param);
            auto opr = checker_bwd_filter.opr();
            opr->param() = param;
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            for (auto algo : opr->get_all_algorithms_info_safe(ily, oly, fly)) {
                used_algos_bwd_flt.insert(algo.desc);
                opr->execution_policy().algo = algo.desc;
                construct_sub_execution_policy_heuristic<ConvolutionBackwardFilter>(
                        opr->execution_policy(), {ily, oly, fly}, param_str,
                        opr->handle());
                checker_bwd_filter
                        .set_epsilon(
                                eps_getter(dtype == 1, 2, algo.desc.name.c_str()))
                        .execl({ily, oly, fly});
                opr->execution_policy() = {};
                ASSERT_TRUE(checker_bwd_filter.prev_succ())
                        << errmsg(algo.desc.name.c_str());
            }
        }
    }
}
  659. // vim: syntax=cpp.doxygen