You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

padding.h 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. /**
  2. * \file dnn/test/common/padding.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include <cstddef>
  14. #include <iostream>
  15. #include "megdnn/basic_types.h"
  16. #include "megdnn/opr_param_defs.h"
  17. namespace megdnn {
  18. namespace test {
  19. namespace padding {
  20. struct TestArg {
  21. param::Padding param;
  22. TensorShape src;
  23. TensorShape dst;
  24. TestArg(param::Padding _param, TensorShape _src, TensorShape _dst)
  25. : param(_param), src(_src), dst(_dst) {}
  26. };
  27. inline std::vector<TestArg> get_args() {
  28. size_t src_shape_dim0 = 5;
  29. size_t src_shape_dim1 = 5;
  30. size_t src_shape_dim2 = 5;
  31. size_t src_shape_dim3 = 5;
  32. size_t src_shape_dim4 = 5;
  33. size_t src_shape_dim5 = 5;
  34. size_t src_shape_dim6 = 5;
  35. size_t dst_shape_dim0 = 8;
  36. size_t dst_shape_dim1 = 8;
  37. size_t dst_shape_dim2 = 8;
  38. size_t dst_shape_dim3 = 8;
  39. size_t dst_shape_dim4 = 8;
  40. size_t dst_shape_dim5 = 8;
  41. size_t dst_shape_dim6 = 8;
  42. std::vector<TestArg> args;
  43. param::Padding cur_param;
  44. cur_param.front_offset_dim0 = 0;
  45. cur_param.front_offset_dim1 = 0;
  46. cur_param.front_offset_dim2 = 0;
  47. cur_param.front_offset_dim3 = 0;
  48. cur_param.front_offset_dim4 = 0;
  49. cur_param.front_offset_dim5 = 0;
  50. cur_param.front_offset_dim6 = 0;
  51. cur_param.back_offset_dim0 = 0;
  52. cur_param.back_offset_dim1 = 0;
  53. cur_param.back_offset_dim2 = 0;
  54. cur_param.back_offset_dim3 = 0;
  55. cur_param.back_offset_dim4 = 0;
  56. cur_param.back_offset_dim5 = 0;
  57. cur_param.back_offset_dim6 = 0;
  58. cur_param.padding_val = 2;
  59. cur_param.front_offset_dim0 = 1;
  60. cur_param.back_offset_dim0 = 2;
  61. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  62. args.emplace_back(
  63. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  64. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  65. args.emplace_back(
  66. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  67. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  68. args.emplace_back(
  69. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  70. cur_param.front_offset_dim1 = 2;
  71. cur_param.back_offset_dim1 = 1;
  72. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  73. args.emplace_back(
  74. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  75. TensorShape{dst_shape_dim0, dst_shape_dim1});
  76. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  77. args.emplace_back(
  78. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  79. TensorShape{dst_shape_dim0, dst_shape_dim1});
  80. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  81. args.emplace_back(
  82. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  83. TensorShape{dst_shape_dim0, dst_shape_dim1});
  84. cur_param.front_offset_dim2 = 1;
  85. cur_param.back_offset_dim2 = 2;
  86. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  87. args.emplace_back(
  88. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  89. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  90. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  91. args.emplace_back(
  92. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  93. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  94. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  95. args.emplace_back(
  96. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  97. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  98. cur_param.front_offset_dim3 = 0;
  99. cur_param.back_offset_dim3 = 3;
  100. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  101. args.emplace_back(
  102. cur_param,
  103. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  104. TensorShape{
  105. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  106. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  107. args.emplace_back(
  108. cur_param,
  109. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  110. TensorShape{
  111. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  112. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  113. args.emplace_back(
  114. cur_param,
  115. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  116. TensorShape{
  117. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  118. cur_param.front_offset_dim4 = 3;
  119. cur_param.back_offset_dim4 = 0;
  120. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  121. args.emplace_back(
  122. cur_param,
  123. TensorShape{
  124. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  125. src_shape_dim4},
  126. TensorShape{
  127. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  128. dst_shape_dim4});
  129. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  130. args.emplace_back(
  131. cur_param,
  132. TensorShape{
  133. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  134. src_shape_dim4},
  135. TensorShape{
  136. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  137. dst_shape_dim4});
  138. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  139. args.emplace_back(
  140. cur_param,
  141. TensorShape{
  142. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  143. src_shape_dim4},
  144. TensorShape{
  145. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  146. dst_shape_dim4});
  147. cur_param.front_offset_dim5 = 1;
  148. cur_param.back_offset_dim5 = 2;
  149. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  150. args.emplace_back(
  151. cur_param,
  152. TensorShape{
  153. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  154. src_shape_dim4, src_shape_dim5},
  155. TensorShape{
  156. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  157. dst_shape_dim4, dst_shape_dim5});
  158. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  159. args.emplace_back(
  160. cur_param,
  161. TensorShape{
  162. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  163. src_shape_dim4, src_shape_dim5},
  164. TensorShape{
  165. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  166. dst_shape_dim4, dst_shape_dim5});
  167. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  168. args.emplace_back(
  169. cur_param,
  170. TensorShape{
  171. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  172. src_shape_dim4, src_shape_dim5},
  173. TensorShape{
  174. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  175. dst_shape_dim4, dst_shape_dim5});
  176. cur_param.front_offset_dim6 = 0;
  177. cur_param.front_offset_dim6 = 3;
  178. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  179. args.emplace_back(
  180. cur_param,
  181. TensorShape{
  182. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  183. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  184. TensorShape{
  185. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  186. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  187. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  188. args.emplace_back(
  189. cur_param,
  190. TensorShape{
  191. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  192. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  193. TensorShape{
  194. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  195. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  196. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  197. args.emplace_back(
  198. cur_param,
  199. TensorShape{
  200. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  201. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  202. TensorShape{
  203. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  204. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  205. return args;
  206. }
  207. inline std::vector<TestArg> get_args_backward() {
  208. size_t src_shape_dim0 = 8;
  209. size_t src_shape_dim1 = 8;
  210. size_t src_shape_dim2 = 8;
  211. size_t src_shape_dim3 = 8;
  212. size_t src_shape_dim4 = 8;
  213. size_t src_shape_dim5 = 8;
  214. size_t src_shape_dim6 = 8;
  215. size_t dst_shape_dim0 = 5;
  216. size_t dst_shape_dim1 = 5;
  217. size_t dst_shape_dim2 = 5;
  218. size_t dst_shape_dim3 = 5;
  219. size_t dst_shape_dim4 = 5;
  220. size_t dst_shape_dim5 = 5;
  221. size_t dst_shape_dim6 = 5;
  222. std::vector<TestArg> args;
  223. param::Padding cur_param;
  224. cur_param.front_offset_dim0 = 0;
  225. cur_param.front_offset_dim1 = 0;
  226. cur_param.front_offset_dim2 = 0;
  227. cur_param.front_offset_dim3 = 0;
  228. cur_param.front_offset_dim4 = 0;
  229. cur_param.front_offset_dim5 = 0;
  230. cur_param.front_offset_dim6 = 0;
  231. cur_param.back_offset_dim0 = 0;
  232. cur_param.back_offset_dim1 = 0;
  233. cur_param.back_offset_dim2 = 0;
  234. cur_param.back_offset_dim3 = 0;
  235. cur_param.back_offset_dim4 = 0;
  236. cur_param.back_offset_dim5 = 0;
  237. cur_param.back_offset_dim6 = 0;
  238. cur_param.padding_val = 2;
  239. cur_param.front_offset_dim0 = 1;
  240. cur_param.back_offset_dim0 = 2;
  241. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  242. args.emplace_back(
  243. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  244. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  245. args.emplace_back(
  246. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  247. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  248. args.emplace_back(
  249. cur_param, TensorShape{src_shape_dim0}, TensorShape{dst_shape_dim0});
  250. cur_param.front_offset_dim1 = 2;
  251. cur_param.back_offset_dim1 = 1;
  252. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  253. args.emplace_back(
  254. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  255. TensorShape{dst_shape_dim0, dst_shape_dim1});
  256. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  257. args.emplace_back(
  258. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  259. TensorShape{dst_shape_dim0, dst_shape_dim1});
  260. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  261. args.emplace_back(
  262. cur_param, TensorShape{src_shape_dim0, src_shape_dim1},
  263. TensorShape{dst_shape_dim0, dst_shape_dim1});
  264. cur_param.front_offset_dim2 = 1;
  265. cur_param.back_offset_dim2 = 2;
  266. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  267. args.emplace_back(
  268. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  269. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  270. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  271. args.emplace_back(
  272. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  273. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  274. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  275. args.emplace_back(
  276. cur_param, TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2},
  277. TensorShape{dst_shape_dim0, dst_shape_dim1, dst_shape_dim2});
  278. cur_param.front_offset_dim3 = 0;
  279. cur_param.back_offset_dim3 = 3;
  280. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  281. args.emplace_back(
  282. cur_param,
  283. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  284. TensorShape{
  285. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  286. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  287. args.emplace_back(
  288. cur_param,
  289. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  290. TensorShape{
  291. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  292. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  293. args.emplace_back(
  294. cur_param,
  295. TensorShape{src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3},
  296. TensorShape{
  297. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3});
  298. cur_param.front_offset_dim4 = 3;
  299. cur_param.back_offset_dim4 = 0;
  300. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  301. args.emplace_back(
  302. cur_param,
  303. TensorShape{
  304. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  305. src_shape_dim4},
  306. TensorShape{
  307. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  308. dst_shape_dim4});
  309. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  310. args.emplace_back(
  311. cur_param,
  312. TensorShape{
  313. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  314. src_shape_dim4},
  315. TensorShape{
  316. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  317. dst_shape_dim4});
  318. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  319. args.emplace_back(
  320. cur_param,
  321. TensorShape{
  322. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  323. src_shape_dim4},
  324. TensorShape{
  325. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  326. dst_shape_dim4});
  327. cur_param.front_offset_dim5 = 1;
  328. cur_param.back_offset_dim5 = 2;
  329. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  330. args.emplace_back(
  331. cur_param,
  332. TensorShape{
  333. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  334. src_shape_dim4, src_shape_dim5},
  335. TensorShape{
  336. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  337. dst_shape_dim4, dst_shape_dim5});
  338. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  339. args.emplace_back(
  340. cur_param,
  341. TensorShape{
  342. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  343. src_shape_dim4, src_shape_dim5},
  344. TensorShape{
  345. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  346. dst_shape_dim4, dst_shape_dim5});
  347. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  348. args.emplace_back(
  349. cur_param,
  350. TensorShape{
  351. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  352. src_shape_dim4, src_shape_dim5},
  353. TensorShape{
  354. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  355. dst_shape_dim4, dst_shape_dim5});
  356. cur_param.front_offset_dim6 = 0;
  357. cur_param.back_offset_dim6 = 3;
  358. cur_param.padding_mode = param::Padding::PaddingMode::CONSTANT;
  359. args.emplace_back(
  360. cur_param,
  361. TensorShape{
  362. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  363. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  364. TensorShape{
  365. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  366. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  367. cur_param.padding_mode = param::Padding::PaddingMode::REPLICATE;
  368. args.emplace_back(
  369. cur_param,
  370. TensorShape{
  371. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  372. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  373. TensorShape{
  374. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  375. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  376. cur_param.padding_mode = param::Padding::PaddingMode::REFLECT;
  377. args.emplace_back(
  378. cur_param,
  379. TensorShape{
  380. src_shape_dim0, src_shape_dim1, src_shape_dim2, src_shape_dim3,
  381. src_shape_dim4, src_shape_dim5, src_shape_dim6},
  382. TensorShape{
  383. dst_shape_dim0, dst_shape_dim1, dst_shape_dim2, dst_shape_dim3,
  384. dst_shape_dim4, dst_shape_dim5, dst_shape_dim6});
  385. return args;
  386. }
  387. } // namespace padding
  388. } // namespace test
  389. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。