You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

padding.h 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. #pragma once
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
  6. namespace megdnn {
  7. namespace test {
  8. namespace padding {
  9. struct TestArg {
  10. param::Padding param;
  11. TensorShape src;
  12. TensorShape dst;
  13. TestArg(param::Padding _param, TensorShape _src, TensorShape _dst)
  14. : param(_param), src(_src), dst(_dst) {}
  15. };
  16. inline std::vector<TestArg> get_args() {
  17. size_t src_shape_dim0 = 5;
  18. size_t src_shape_dim1 = 5;
  19. size_t src_shape_dim2 = 5;
  20. size_t src_shape_dim3 = 5;
  21. size_t src_shape_dim4 = 5;
  22. size_t src_shape_dim5 = 5;
  23. size_t src_shape_dim6 = 5;
  24. size_t dst_shape_dim0 = 8;
  25. size_t dst_shape_dim1 = 8;
  26. size_t dst_shape_dim2 = 8;
  27. size_t dst_shape_dim3 = 8;
  28. size_t dst_shape_dim4 = 8;
  29. size_t dst_shape_dim5 = 8;
  30. size_t dst_shape_dim6 = 8;
  31. std::vector<TestArg> args;
  32. param::Padding cur_param;
  33. cur_param.front_offset_dim0 = 0;
  34. cur_param.front_offset_dim1 = 0;
  35. cur_param.front_offset_dim2 = 0;
  36. cur_param.front_offset_dim3 = 0;
  37. cur_param.front_offset_dim4 = 0;
  38. cur_param.front_offset_dim5 = 0;
  39. cur_param.front_offset_dim6 = 0;
  40. cur_param.back_offset_dim0 = 0;
  41. cur_param.back_offset_dim1 = 0;
  42. cur_param.back_offset_dim2 = 0;
  43. cur_param.back_offset_dim3 = 0;
  44. cur_param.back_offset_dim4 = 0;
  45. cur_param.back_offset_dim5 = 0;
  46. cur_param.back_offset_dim6 = 0;
  47. cur_param.padding_val = 2;
  48. cur_param.front_offset_dim0 = 1;
  49. cur_param.back_offset_dim0 = 2;
  50. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  51. args.emplace_back(
  52. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  53. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  54. args.emplace_back(
  55. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  56. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  57. args.emplace_back(
  58. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  59. cur_param.front_offset_dim1 = 2;
  60. cur_param.back_offset_dim1 = 1;
  61. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  62. args.emplace_back(
  63. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  64. TensorShape{dst_shape_dim0, dst_shape_dim1});
  65. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  66. args.emplace_back(
  67. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  68. TensorShape{dst_shape_dim0, dst_shape_dim1});
  69. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  70. args.emplace_back(
  71. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  72. TensorShape{dst_shape_dim0, dst_shape_dim1});
  73. cur_param.front_offset_dim2 = 1;
  74. cur_param.back_offset_dim2 = 2;
  75. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  76. args.emplace_back(
  77. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  78. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  79. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  80. args.emplace_back(
  81. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  82. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  83. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  84. args.emplace_back(
  85. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  86. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  87. cur_param.front_offset_dim3 = 0;
  88. cur_param.back_offset_dim3 = 3;
  89. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  90. args.emplace_back(
  91. cur_param,
  92. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  93. TensorShape{
  94. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  95. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  96. args.emplace_back(
  97. cur_param,
  98. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  99. TensorShape{
  100. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  101. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  102. args.emplace_back(
  103. cur_param,
  104. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  105. TensorShape{
  106. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  107. cur_param.front_offset_dim4 = 3;
  108. cur_param.back_offset_dim4 = 0;
  109. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  110. args.emplace_back(
  111. cur_param,
  112. TensorShape{
  113. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  114. src_shape_dim4},
  115. TensorShape{
  116. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  117. dst_shape_dim4});
  118. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  119. args.emplace_back(
  120. cur_param,
  121. TensorShape{
  122. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  123. src_shape_dim4},
  124. TensorShape{
  125. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  126. dst_shape_dim4});
  127. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  128. args.emplace_back(
  129. cur_param,
  130. TensorShape{
  131. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  132. src_shape_dim4},
  133. TensorShape{
  134. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  135. dst_shape_dim4});
  136. cur_param.front_offset_dim5 = 1;
  137. cur_param.back_offset_dim5 = 2;
  138. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  139. args.emplace_back(
  140. cur_param,
  141. TensorShape{
  142. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  143. src_shape_dim4, src_shape_dim5},
  144. TensorShape{
  145. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  146. dst_shape_dim4, dst_shape_dim5});
  147. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  148. args.emplace_back(
  149. cur_param,
  150. TensorShape{
  151. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  152. src_shape_dim4, src_shape_dim5},
  153. TensorShape{
  154. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  155. dst_shape_dim4, dst_shape_dim5});
  156. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  157. args.emplace_back(
  158. cur_param,
  159. TensorShape{
  160. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  161. src_shape_dim4, src_shape_dim5},
  162. TensorShape{
  163. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  164. dst_shape_dim4, dst_shape_dim5});
  165. cur_param.front_offset_dim6 = 0;
  166. cur_param.front_offset_dim6 = 3;
  167. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  168. args.emplace_back(
  169. cur_param,
  170. TensorShape{
  171. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  172. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  173. TensorShape{
  174. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  175. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  176. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  177. args.emplace_back(
  178. cur_param,
  179. TensorShape{
  180. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  181. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  182. TensorShape{
  183. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  184. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  185. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  186. args.emplace_back(
  187. cur_param,
  188. TensorShape{
  189. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  190. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  191. TensorShape{
  192. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  193. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  194. return args;
  195. }
  196. inline std::vector<TestArg> get_args_backward() {
  197. size_t src_shape_dim0 = 8;
  198. size_t src_shape_dim1 = 8;
  199. size_t src_shape_dim2 = 8;
  200. size_t src_shape_dim3 = 8;
  201. size_t src_shape_dim4 = 8;
  202. size_t src_shape_dim5 = 8;
  203. size_t src_shape_dim6 = 8;
  204. size_t dst_shape_dim0 = 5;
  205. size_t dst_shape_dim1 = 5;
  206. size_t dst_shape_dim2 = 5;
  207. size_t dst_shape_dim3 = 5;
  208. size_t dst_shape_dim4 = 5;
  209. size_t dst_shape_dim5 = 5;
  210. size_t dst_shape_dim6 = 5;
  211. std::vector<TestArg> args;
  212. param::Padding cur_param;
  213. cur_param.front_offset_dim0 = 0;
  214. cur_param.front_offset_dim1 = 0;
  215. cur_param.front_offset_dim2 = 0;
  216. cur_param.front_offset_dim3 = 0;
  217. cur_param.front_offset_dim4 = 0;
  218. cur_param.front_offset_dim5 = 0;
  219. cur_param.front_offset_dim6 = 0;
  220. cur_param.back_offset_dim0 = 0;
  221. cur_param.back_offset_dim1 = 0;
  222. cur_param.back_offset_dim2 = 0;
  223. cur_param.back_offset_dim3 = 0;
  224. cur_param.back_offset_dim4 = 0;
  225. cur_param.back_offset_dim5 = 0;
  226. cur_param.back_offset_dim6 = 0;
  227. cur_param.padding_val = 2;
  228. cur_param.front_offset_dim0 = 1;
  229. cur_param.back_offset_dim0 = 2;
  230. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  231. args.emplace_back(
  232. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  233. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  234. args.emplace_back(
  235. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  236. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  237. args.emplace_back(
  238. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  239. cur_param.front_offset_dim1 = 2;
  240. cur_param.back_offset_dim1 = 1;
  241. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  242. args.emplace_back(
  243. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  244. TensorShape{dst_shape_dim0, dst_shape_dim1});
  245. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  246. args.emplace_back(
  247. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  248. TensorShape{dst_shape_dim0, dst_shape_dim1});
  249. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  250. args.emplace_back(
  251. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  252. TensorShape{dst_shape_dim0, dst_shape_dim1});
  253. cur_param.front_offset_dim2 = 1;
  254. cur_param.back_offset_dim2 = 2;
  255. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  256. args.emplace_back(
  257. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  258. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  259. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  260. args.emplace_back(
  261. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  262. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  263. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  264. args.emplace_back(
  265. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  266. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  267. cur_param.front_offset_dim3 = 0;
  268. cur_param.back_offset_dim3 = 3;
  269. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  270. args.emplace_back(
  271. cur_param,
  272. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  273. TensorShape{
  274. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  275. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  276. args.emplace_back(
  277. cur_param,
  278. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  279. TensorShape{
  280. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  281. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  282. args.emplace_back(
  283. cur_param,
  284. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  285. TensorShape{
  286. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  287. cur_param.front_offset_dim4 = 3;
  288. cur_param.back_offset_dim4 = 0;
  289. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  290. args.emplace_back(
  291. cur_param,
  292. TensorShape{
  293. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  294. src_shape_dim4},
  295. TensorShape{
  296. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  297. dst_shape_dim4});
  298. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  299. args.emplace_back(
  300. cur_param,
  301. TensorShape{
  302. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  303. src_shape_dim4},
  304. TensorShape{
  305. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  306. dst_shape_dim4});
  307. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  308. args.emplace_back(
  309. cur_param,
  310. TensorShape{
  311. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  312. src_shape_dim4},
  313. TensorShape{
  314. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  315. dst_shape_dim4});
  316. cur_param.front_offset_dim5 = 1;
  317. cur_param.back_offset_dim5 = 2;
  318. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  319. args.emplace_back(
  320. cur_param,
  321. TensorShape{
  322. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  323. src_shape_dim4, src_shape_dim5},
  324. TensorShape{
  325. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  326. dst_shape_dim4, dst_shape_dim5});
  327. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  328. args.emplace_back(
  329. cur_param,
  330. TensorShape{
  331. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  332. src_shape_dim4, src_shape_dim5},
  333. TensorShape{
  334. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  335. dst_shape_dim4, dst_shape_dim5});
  336. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  337. args.emplace_back(
  338. cur_param,
  339. TensorShape{
  340. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  341. src_shape_dim4, src_shape_dim5},
  342. TensorShape{
  343. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  344. dst_shape_dim4, dst_shape_dim5});
  345. cur_param.front_offset_dim6 = 0;
  346. cur_param.back_offset_dim6 = 3;
  347. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  348. args.emplace_back(
  349. cur_param,
  350. TensorShape{
  351. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  352. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  353. TensorShape{
  354. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  355. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  356. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  357. args.emplace_back(
  358. cur_param,
  359. TensorShape{
  360. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  361. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  362. TensorShape{
  363. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  364. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  365. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  366. args.emplace_back(
  367. cur_param,
  368. TensorShape{
  369. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  370. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  371. TensorShape{
  372. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  373. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  374. return args;
  375. }
  376. } // namespace padding
  377. } // namespace test
  378. } // namespace megdnn