
// record1.cpp

#include "test/naive/fixture.h"
#include "megdnn/oprs.h"
#include "test/common/task_record_check.h"
#include "test/common/adaptive_pooling.h"
#include "test/common/cond_take.h"
#include "test/common/convolution3d.h"
#include "test/common/local.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"
#include "test/common/separable_conv.h"
#include "test/common/warp_affine.h"
#include "test/common/warp_perspective.h"
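
// TaskRecordChecker (see test/common/task_record_check.h, included above) runs
// each operator through megdnn's task-record path and checks the recorded
// results against a reference run; all tests below use it in place of the
// plain Checker.
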
namespace {
using namespace megdnn;
using namespace test;
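
// Fills a tensor with the values 0 .. n-1 in random order, giving
// argmax/argmin a well-defined, unique answer.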
class ArgmxxRNG final : public RNG {
public:
    void gen(const TensorND& tensor) override {
        auto offset = tensor.layout.span().low_elem;
        auto nr_elems = tensor.layout.span().dist_elem();

#define cb(DType)                                             \
    if (tensor.layout.dtype == DType()) {                     \
        using ctype = typename DTypeTrait<DType>::ctype;      \
        auto ptr = tensor.ptr<ctype>();                       \
        for (size_t i = 0; i < nr_elems; ++i) {               \
            ptr[offset + i] = i;                              \
        }                                                     \
        COMPAT_RANDOM(ptr + offset, ptr + offset + nr_elems); \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
#undef cb
    }
};
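
// Runs an Argmax/Argmin operator through the record checker: every axis of a
// small 4-D tensor across several dtypes, followed by reduction lengths that
// exercise the 1-step, 2-step, 3-step and very large reduction paths.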
template <typename Argmxx>
void test_argmxx() {
    TaskRecordChecker<Argmxx> checker(2);
    checker.set_dtype(1, dtype::Int32());
    using Param = typename Argmxx::Param;
    ArgmxxRNG rng;
    checker.set_rng(0, &rng);
    for (size_t axis = 0; axis < 4; ++axis) {
        Param param;
        param.axis = axis;
        checker.set_param(param)
                .set_dtype(0, dtype::Float32())
                .execs({{2, 3, 4, 5}, {}});
        checker.set_param(param)
                .set_dtype(0, dtype::Float16())
                .execs({{2, 3, 4, 5}, {}});
        checker.set_param(param).set_dtype(0, dtype::Int32()).execs({{2, 3, 4, 5}, {}});
        checker.set_param(param).set_dtype(0, dtype::Int16()).execs({{2, 3, 4, 5}, {}});
        checker.set_param(param).set_dtype(0, dtype::Int8()).execs({{2, 3, 4, 5}, {}});
        checker.set_param(param).set_dtype(0, dtype::Uint8()).execs({{2, 3, 4, 5}, {}});
    }
    checker.set_dtype(0, dtype::Float32());
    Param param;
    param.axis = 1;
    checker.set_param(param);
    // 1-step
    checker.execs({{2, 64, 32}, {}});
    // 2-step
    checker.execs({{2, 192, 32}, {}});
    // 3-step
    checker.execs({{2, 4333, 32}, {}});
    // single reduce
    checker.execs({{2, 1, 1}, {}});
    checker.execs({{2, 1 + 1, 1}, {}});
    checker.execs({{2, 2048 + 1, 1}, {}});
    checker.execs({{2, 2048 * 2048 + 1, 1}, {}});
    checker.execs({{2, 1 + 1, 31}, {}});
    checker.execs({{2, 16 + 1, 31}, {}});
    checker.execs({{2, 16 * 16 + 1, 31}, {}});
    checker.execs({{2, 16 * 16 * 16 + 1, 31}, {}});
    checker.execs({{2, 16 * 16 * 16 * 16 + 1, 31}, {}});
    checker.execs({{3, 256 * 256 + 1, 2}, {}});
    checker.execs({{3, 128 * 128 + 1, 3}, {}});
    checker.execs({{3, 64 * 64 + 1, 7}, {}});
    checker.execs({{3, 32 * 32 + 1, 15}, {}});
    checker.execs({{3, 512, 500}, {}});
    // very large reduce
    checker.execs({{1, 4194304, 1}, {}});
}
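
// Input generator for argsort: a strictly decreasing sequence when reverse
// order is requested, otherwise a shuffled sequence centred around zero.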
class ArgsortRNG final : public RNG {
    bool m_rev_order = false;
    DType m_dtype;

    template <typename T>
    void fill(T* ptr, int n) {
        if (m_rev_order) {
            for (int i = 0; i < n; ++i)
                ptr[i] = static_cast<T>(n / 2 - i);
        } else {
            for (int i = 0; i < n; ++i)
                ptr[i] = static_cast<T>(i - n / 2);
            COMPAT_RANDOM(ptr, ptr + n);
        }
    }

    void gen(const TensorND& tensor) override {
        auto n = tensor.layout.total_nr_elems();
        if (m_dtype == dtype::Float32{}) {
            fill(tensor.ptr<dt_float32>(), n);
        } else {
            megdnn_assert(m_dtype == dtype::Int32{});
            fill(tensor.ptr<dt_int32>(), n);
        }
    }

public:
    ArgsortRNG(DType dt) : m_dtype{dt} {}
    void set_rev_order(bool flag) { m_rev_order = flag; }
};
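
// Checks ArgsortForward in ascending and descending order over a range of row
// lengths, plus one large reverse-ordered input.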
void run_forward_test(DType dtype) {
    TaskRecordChecker<ArgsortForward> checker(2);
    using Param = Argsort::Param;
    using Order = Param::Order;
    ArgsortRNG rng{dtype};
    checker.set_dtype(2, dtype::Int32());
    checker.set_dtype(0, dtype).set_rng(0, &rng);
    for (size_t i = 3; i < 10240; i *= 2) {
        Param param;
        param.order = Order::ASCENDING;
        checker.set_param(param).execs({{3, i + 1}, {}, {}});
        param.order = Order::DESCENDING;
        checker.set_param(param).execs({{3, i - 1}, {}, {}});
        checker.set_param(param).execs({{13, i + 3}, {}, {}});
    }
    {
        // reverse sort large array
        constexpr size_t N = 200003;
        rng.set_rev_order(true);
        Param param;
        param.order = Order::ASCENDING;
        checker.set_param(param).execs({{1, N}, {}, {}});
    }
}
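
// Index generator for ArgsortBackward: each row is an independent random
// permutation of 0 .. n-1, i.e. a valid argsort result.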
class IdxRng final : public RNG {
    void gen(const TensorND& tensor) override {
        auto ptr = tensor.ptr<dt_int32>();
        auto m = tensor.layout[0], n = tensor.layout[1];
        for (size_t i = 0; i < m; ++i) {
            for (size_t j = 0; j < n; ++j) {
                ptr[j] = j;
            }
            COMPAT_RANDOM(ptr, ptr + n);
            ptr += n;
        }
    }
};
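
// Checks ArgsortBackward with matching source/destination widths and with a
// destination that is wider than the sorted rows.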
void run_backward_test(DType dtype) {
    IdxRng rng;
    TaskRecordChecker<ArgsortBackward> checker(2);
    checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng);
    checker.set_dtype(0, dtype);
    checker.set_dtype(2, dtype);
    for (size_t i = 16; i < 4096; i *= 2) {
        checker.execs({{3, i}, {3, i}, {3, i}});
        checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}});
        checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}});
    }
}
}  // anonymous namespace

namespace megdnn {
namespace test {

//! adaptive pooling
TEST_F(NAIVE, ADAPTIVE_POOLING_FORWARD_RECORD) {
    TaskRecordChecker<AdaptivePooling> checker(2);
    auto args = adaptive_pooling::get_args();
    using Format = param::AdaptivePooling::Format;
    DType dtype = dtype::Float32();
    for (auto&& arg : args) {
        auto param = arg.param;
        auto src = arg.ishape;
        auto dst = arg.oshape;
        param.format = Format::NCHW;
        checker.set_epsilon(1e-2);
        checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
                TensorShapeArray{src, dst, {}});
        break;
    }
}

TEST_F(NAIVE, ADAPTIVE_POOLING_BACKWARD_RECORD) {
    TaskRecordChecker<AdaptivePooling> checker(2);
    auto args = adaptive_pooling::get_args();
    for (auto&& arg : args) {
        TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
        TensorLayout olayout = TensorLayout(arg.oshape, dtype::Float32());
        DType dtype = dtype::Float32();
        checker.set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_dtype(3, dtype)
                .set_param(arg.param)
                .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
        break;
    }
}

//! add update
TEST_F(NAIVE, ADD_UPDATE_RECORD) {
    TaskRecordChecker<AddUpdate> checker(2);
    param::AddUpdate p{2, -1, 3};
    checker.set_param(p)
            .set_dtype(0, dtype::BFloat16())
            .set_dtype(1, dtype::BFloat16())
            .execs({{2, 2, 3}, {2, 2, 3}});
}

//! argxx
TEST_F(NAIVE, ARGXX_RECORD) {
    test_argmxx<Argmax>();
    test_argmxx<Argmin>();
}

//! argsort
TEST_F(NAIVE, ARGSORT_FORWARD_RECORD) {
    run_forward_test(dtype::Float32{});
    run_forward_test(dtype::Int32{});
}

TEST_F(NAIVE, ARGSORT_BACKWARD_RECORD) {
    run_backward_test(dtype::Float32{});
    run_backward_test(dtype::Int32{});
}

TEST_F(NAIVE, BATCH_CONV_BIAS_QS8_RECORD) {
    TaskRecordChecker<BatchConvBiasForward> checker(2);
    UniformIntRNG const_rng{1, 1};
    UniformIntRNG rng{-5, 5};
    UniformIntRNG bias_rng{-50, 50};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(3, dtype::QuantizedS8{1.1f})
            .set_dtype(4, dtype::QuantizedS8{1.1f})
            .set_epsilon(1 + 1e-3);
    param::BatchConvBias param;
    param.pad_h = 2, param.pad_w = 1;
    param.stride_h = 1, param.stride_w = 2;
    param.format = param::BatchConvBias::Format::NCHW4;
    checker.set_param(param).execs(
            {{32, 4, 24, 24, 4}, {32, 32, 4, 1, 1, 4}, {1, 8, 1, 1, 4}, {}, {}});
}

//! batched_matmul
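// Layouts are built explicitly so that non-default row and batch strides also
// go through the record path.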
TEST_F(NAIVE, BATCH_MAT_MUL_RECORD) {
    TaskRecordChecker<BatchedMatrixMulForward> checker(2);
    using TestArg = matrix_mul::TestArg;
    //! return expect if stride == -1, stride otherwise
    auto stride_val = [](size_t stride, size_t expect) -> size_t {
        if (stride == TestArg::UNSET_STRIDE_VAL) {
            return expect;
        } else {
            return stride;
        }
    };
    using Param = MatrixMul::Param;
    std::vector<TestArg> args;
    args = matrix_mul::get_batched_matmul_args();
    for (auto& arg : args) {
        if (arg.b == 1) {
            continue;
        }
        size_t m = arg.m, n = arg.n, k = arg.k;
        Param param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        size_t A0 = m, A1 = k, B0 = k, B1 = n;
        TensorShape A, B;
        if (param.transposeA) {
            std::swap(A0, A1);
        }
        if (param.transposeB) {
            std::swap(B0, B1);
        }
        ptrdiff_t A_stride = arg.A_stride, B_stride = arg.B_stride,
                  C_stride = arg.C_stride, A_batch_stride = arg.A_batch_stride,
                  B_batch_stride = arg.B_batch_stride,
                  C_batch_stride = arg.C_batch_stride;
        A_stride = stride_val(A_stride, A1);
        B_stride = stride_val(B_stride, B1);
        C_stride = stride_val(C_stride, n);
        A_batch_stride = stride_val(A_batch_stride, A0 * A_stride);
        B_batch_stride = stride_val(B_batch_stride, B0 * B_stride);
        C_batch_stride = stride_val(C_batch_stride, m * C_stride);
        checker.set_param(param);
        checker.execl(
                {TensorLayout{
                         {arg.b, A0, A1},
                         {A_batch_stride, A_stride, 1},
                         dtype::Float32()},
                 TensorLayout{
                         {arg.b, B0, B1},
                         {B_batch_stride, B_stride, 1},
                         dtype::Float32()},
                 TensorLayout{
                         {arg.b, m, n},
                         {C_batch_stride, C_stride, 1},
                         dtype::Float32()}});
        break;
    }
}

//! BN
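// Forward is exercised in TRAINING mode first and then in INFERENCE mode,
// where the running mean/variance inputs are drawn from a positive RNG.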
TEST_F(NAIVE, BN_FORWARD_RECORD) {
    TaskRecordChecker<BNForward> checker(2);
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Float32())
            .set_epsilon(1e-3);
    param::BN param;
    param.fwd_mode = param::BN::FwdMode::TRAINING;
    param.param_dim = param::BN::ParamDim::DIM_1C11;
    param.epsilon = 1e-3;
    for (size_t n : {1, 2}) {
        for (size_t c : {1, 128}) {
            for (size_t i : {2, 14}) {
                for (float f : {0.5, 1.0}) {
                    param.avg_factor = f;
                    checker.set_param(param);
                    TensorShape src{n, c, i, i};
                    TensorShape inp{1, c, 1, 1};
                    checker.execs(
                            {src,  //! src -> input
                             inp,  //! bn_scale -> input
                             inp,  //! bn_bias -> input
                             inp,  //! mean -> output
                             inp,  //! variance -> output
                             inp,  //! batch_mean -> output
                             inp,  //! batch_inv_variance -> output
                             {},   //! reserve -> output
                             {}});
                }
            }
        }
    }

    UniformFloatRNG rng(1.0f, 2.0f);
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Float32())
            .set_dtype(3, dtype::Float32())
            .set_dtype(4, dtype::Float32())
            .set_rng(3, &rng)
            .set_rng(4, &rng)
            .set_epsilon(1e-3);
    param.fwd_mode = param::BN::FwdMode::INFERENCE;
    param.param_dim = param::BN::ParamDim::DIM_1C11;
    param.epsilon = 1e-3;
    checker.set_param(param);
    for (size_t n : {1, 2}) {
        for (size_t c : {1, 128}) {
            for (size_t i : {2, 14}) {
                TensorShape src{n, c, i, i};
                TensorShape inp{1, c, 1, 1};
                checker.exec({
                        src,  //! src -> input
                        inp,  //! bn_scale -> input
                        inp,  //! bn_bias -> input
                        inp,  //! mean -> input
                        inp,  //! variance -> input
                        {},   //! batch_mean -> output[unused]
                        {},   //! batch_inv_variance -> output[unused]
                        {},   //! reserve -> output
                        {}    //! dst -> output [shape deduced]
                });
            }
        }
    }
}

TEST_F(NAIVE, BN_BACKWARD_RECORD) {
    TaskRecordChecker<BNBackward> checker(2);
    UniformFloatRNG rng(1.0f, 2.0f);
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Float32())
            .set_dtype(3, dtype::Float32())
            .set_dtype(4, dtype::Float32())
            .set_rng(3, &rng);
    param::BN param;
    param.fwd_mode = param::BN::FwdMode::TRAINING;
    param.epsilon = 0.0f;
    checker.set_param(param);
    for (size_t n : {1, 2}) {
        for (size_t c : {3, 128}) {
            for (size_t i : {2, 14}) {
                TensorShape src{n, c, i, i};
                TensorShape inp{1, c, 1, 1};
                checker.exec({
                        src,  //! x -> input
                        src,  //! dy -> input
                        inp,  //! bn_mean -> input
                        inp,  //! bn_ivar -> input
                        inp,  //! bn_scale -> input
                        {},   //! reserve -> input
                        inp,  //! d_bn_scale -> output
                        inp,  //! d_bn_bias -> output
                        src   //! dx -> output
                });
            }
        }
    }
}

//! concat
TEST_F(NAIVE, CONCAT_RECORD) {
    TaskRecordChecker<Concat> checker(2);
    using Param = Concat::Param;
    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()})
        for (size_t axis = 0; axis < 4; ++axis) {
            Param param;
            param.axis = axis;
            TensorShapeArray shapes(4, TensorShape({2, 3, 4, 5}));
            for (size_t i = 0; i < 4; ++i) {
                shapes[i].shape[axis] = i + 1;
            }
            shapes.emplace_back();
            for (size_t i = 0; i < shapes.size(); ++i)
                checker.set_dtype(i, dtype);
            checker.set_param(param).execs(shapes);
        }
}

//! ConvBias
TEST_F(NAIVE, CONV_BIAS_RECORD) {
    TaskRecordChecker<ConvBias> checker(2);
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    checker.set_dtype(0, dtype::QuantizedS8(0.1f))
            .set_dtype(1, dtype::QuantizedS8(0.2f))
            .set_dtype(2, dtype::QuantizedS32(0.02f))
            .set_dtype(3, dtype::QuantizedS32(0.3f))
            .set_dtype(4, dtype::QuantizedS32(0.02f));
    checker.set_param(param).execs(
            {{1, 1, 4, 4}, {3, 1, 3, 3}, {1, 3, 1, 1}, {1, 3, 2, 2}, {}});
}

//! Convolution
TEST_F(NAIVE, CONV_RECORD) {
    TaskRecordChecker<Convolution> checker(2);
    Convolution::Param param;
    param.format = Convolution::Param::Format::NCHW;
    checker.set_param(param).execs({{1, 1, 4, 4}, {3, 1, 3, 3}, {}});
}

//! Conv3D
TEST_F(NAIVE, CONV3D_RECORD) {
    using TestArg = convolution3d::TestArg;
    std::vector<TestArg> args = convolution3d::get_args();
    TaskRecordChecker<Convolution3DForward> checker(2);
    NormalRNG default_rng;
    for (auto&& arg : args) {
        float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3] *
                                  arg.filter[4]);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
    }
}

//! cumprod
TEST_F(NAIVE, CUMPROD_RECORD) {
    TaskRecordChecker<Cumprod> checker(2);
    struct TestArg {
        param::Cumprod param;
        TensorShape shape;
        TestArg(param::Cumprod param, TensorShape shape) : param(param), shape(shape) {}
    };
    std::vector<TestArg> args, args_int32;
    for (auto shape : TensorShapeArray{{1000}, {330, 33}, {10, 10, 10}, {5, 5, 5, 5}}) {
        for (size_t axis = 0; axis < shape.ndim; ++axis) {
            args.emplace_back(param::Cumprod(axis, true, true), shape);
            args.emplace_back(param::Cumprod(axis, true, false), shape);
            args.emplace_back(param::Cumprod(axis, false, true), shape);
            args.emplace_back(param::Cumprod(axis, false, false), shape);
        }
    }
    for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) {
        args.emplace_back(param::Cumprod(0, true, true), shape);
        args.emplace_back(param::Cumprod(0, true, false), shape);
        args.emplace_back(param::Cumprod(0, false, true), shape);
        args.emplace_back(param::Cumprod(0, false, false), shape);
    }
    for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) {
        args_int32.emplace_back(param::Cumprod(0, true, true), shape);
        args_int32.emplace_back(param::Cumprod(0, true, false), shape);
        args_int32.emplace_back(param::Cumprod(0, false, true), shape);
        args_int32.emplace_back(param::Cumprod(0, false, false), shape);
    }
    for (auto arg : args) {
        checker.set_param(arg.param);
        checker.set_epsilon(1e-2);
        checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}});
        checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}});
        checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}});
    }
    for (auto arg : args_int32) {
        checker.set_param(arg.param);
        checker.set_epsilon(1e-2);
        checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}});
    }
}

//! cumsum
TEST_F(NAIVE, CUMSUM_RECORD) {
    TaskRecordChecker<Cumsum> checker(2);
    struct TestArg {
        param::Cumsum param;
        TensorShape shape;
        TestArg(param::Cumsum param, TensorShape shape) : param(param), shape(shape) {}
    };
    std::vector<TestArg> args, args_int32;
    for (auto shape : TensorShapeArray{{1000}, {330, 33}, {10, 10, 10}, {5, 5, 5, 5}}) {
        for (size_t axis = 0; axis < shape.ndim; ++axis) {
            args.emplace_back(param::Cumsum(axis, true, true), shape);
            args.emplace_back(param::Cumsum(axis, true, false), shape);
            args.emplace_back(param::Cumsum(axis, false, true), shape);
            args.emplace_back(param::Cumsum(axis, false, false), shape);
        }
    }
    for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) {
        args.emplace_back(param::Cumsum(0, true, true), shape);
        args.emplace_back(param::Cumsum(0, true, false), shape);
        args.emplace_back(param::Cumsum(0, false, true), shape);
        args.emplace_back(param::Cumsum(0, false, false), shape);
    }
    for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) {
        args_int32.emplace_back(param::Cumsum(0, true, true), shape);
        args_int32.emplace_back(param::Cumsum(0, true, false), shape);
        args_int32.emplace_back(param::Cumsum(0, false, true), shape);
        args_int32.emplace_back(param::Cumsum(0, false, false), shape);
    }
    for (auto arg : args) {
        checker.set_param(arg.param);
        checker.set_epsilon(1e-2);
        checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}});
        checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}});
        checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}});
    }
    for (auto arg : args_int32) {
        checker.set_param(arg.param);
        checker.set_epsilon(1e-2);
        checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}});
    }
}

//! dct
TEST_F(NAIVE, DCT_RECORD) {
    TaskRecordChecker<DctChannelSelectForward> checker(2);
    DctChannelSelectForward::Param param;
    param.format = DctChannelSelectForward::Param::Format::NCHW4;
    checker.set_dtype(0, dtype::Uint8()).set_dtype(3, dtype::QuantizedS8(10.f));
    checker.set_param(param).execs({{1, 1, 16, 16}, {}, {}, {}});
}

//! deformable_conv
TEST_F(NAIVE, DEFORMABLE_CONV_FWD_RECORD) {
    TaskRecordChecker<DeformableConv> checker(2);
    DeformableConv::Param param;
    UniformIntRNG im_rng{0, 4};
    UniformIntRNG filter_rng{0, 4};
    UniformIntRNG offset_rng{-2, 2};
    UniformIntRNG mask_rng{0, 1};
    checker.set_rng(0, &im_rng)
            .set_rng(1, &filter_rng)
            .set_rng(2, &offset_rng)
            .set_rng(3, &mask_rng);
    param.pad_h = 1;
    param.pad_w = 1;
    param.stride_h = 1;
    param.stride_w = 1;
    param.dilate_h = 1;
    param.dilate_w = 1;
    param.format = DeformableConv::Param::Format::NCHW;
    param.sparse = DeformableConv::Param::Sparse::GROUP;
    checker.set_param(param).execs(
            {{1, 2, 5, 5},
             {2, 1, 1, 3, 3},
             {1, 2 * 2 * 3 * 3, 5, 5},
             {1, 2 * 3 * 3, 5, 5},
             {}});
    checker.set_param(param).execs(
            {{1, 2, 5, 5},
             {2, 1, 1, 3, 3},
             {1, 2 * 2 * 3 * 3, 5, 5},
             {1, 2 * 3 * 3, 5, 5},
             {}});
    param.sparse = DeformableConv::Param::Sparse::DENSE;
    checker.set_param(param).execs(
            {{1, 2, 5, 5},
             {2, 2, 3, 3},
             {1, 2 * 2 * 3 * 3, 5, 5},
             {1, 2 * 3 * 3, 5, 5},
             {}});
}

TEST_F(NAIVE, DEFORMABLE_CONV_BWD_FILTER_RECORD) {
    TaskRecordChecker<DeformableConvBackwardFilter> checker(2);
    DeformableConv::Param param;
    UniformIntRNG im_rng{0, 4};
    UniformIntRNG offset_rng{-2, 2};
    UniformIntRNG mask_rng{0, 1};
    UniformIntRNG out_grad_rng{0, 1};
    checker.set_rng(0, &im_rng)
            .set_rng(1, &offset_rng)
            .set_rng(2, &mask_rng)
            .set_rng(3, &out_grad_rng);
    param.pad_h = 1;
    param.pad_w = 1;
    param.stride_h = 1;
    param.stride_w = 1;
    param.dilate_h = 1;
    param.dilate_w = 1;
    param.format = DeformableConv::Param::Format::NCHW;
    param.sparse = DeformableConv::Param::Sparse::GROUP;
    checker.set_param(param).execs(
            {{1, 2, 5, 5},
             {1, 2 * 2 * 3 * 3, 5, 5},
             {1, 2 * 3 * 3, 5, 5},
             {1, 2, 5, 5},
             {2, 1, 1, 3, 3}});
}

TEST_F(NAIVE, DEFORMABLE_CONV_BWD_DATA_RECORD) {
    TaskRecordChecker<DeformableConvBackwardData> checker(2);
    DeformableConv::Param param;
    ConstValue im_rng{1};
    ConstValue filter_rng{0.99};
    ConstValue offset_rng{1.1};
    ConstValue mask_rng{1};
    ConstValue out_grad_rng{1};
    checker.set_rng(0, &im_rng)
            .set_rng(1, &filter_rng)
            .set_rng(2, &offset_rng)
            .set_rng(3, &mask_rng)
            .set_rng(4, &out_grad_rng);
    param.pad_h = 1;
    param.pad_w = 1;
    param.stride_h = 1;
    param.stride_w = 1;
    param.dilate_h = 1;
    param.dilate_w = 1;
    param.format = DeformableConv::Param::Format::NCHW;
    param.sparse = DeformableConv::Param::Sparse::GROUP;
    checker.set_param(param).execs(
            {{1, 2, 5, 5},
             {2, 1, 1, 3, 3},
             {1, 1 * 2 * 3 * 3, 5, 5},
             {1, 1 * 3 * 3, 5, 5},
             {1, 2, 5, 5},
             {1, 2, 5, 5},
             {1, 1 * 2 * 3 * 3, 5, 5},
             {1, 1 * 3 * 3, 5, 5}});
}

//! elemwise
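// Covers activation, binary and unary modes across float and integer dtypes;
// the epsilon is relaxed for Float16/BFloat16 where precision is lower.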
TEST_F(NAIVE, ELEMWISE_COMMON_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_activate = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
                            DType dtype) {
        checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype);
        checker.execs({{N, C, H, W}, {}});
    };
    auto run_binary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
                          DType dtype) {
        checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(
                2, dtype);
        checker.execs({{N, C, H, W}, {N, C, H, W}, {}});
    };
    auto run_unary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
                         DType dtype) {
        checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype);
        checker.execs({{N, C, H, W}, {}});
    };

#define RUN_ACTIVATE(_dt)                         \
    run_activate(4, 32, 10, 10, Mode::RELU, _dt); \
    run_activate(4, 32, 10, 10, Mode::SIGMOID, _dt);
    RUN_ACTIVATE(dtype::Float32());
    RUN_ACTIVATE(dtype::Float16());
    checker.set_epsilon(1e-2);
    RUN_ACTIVATE(dtype::BFloat16());
#undef RUN_ACTIVATE

    checker.set_epsilon(1e-3);

#define RUN_BINARY(_dt)                        \
    run_binary(4, 32, 10, 10, Mode::ADD, _dt); \
    run_binary(4, 32, 10, 10, Mode::SUB, _dt); \
    run_binary(4, 32, 10, 10, Mode::MUL, _dt); \
    run_binary(4, 32, 10, 10, Mode::MIN, _dt); \
    run_binary(4, 32, 10, 10, Mode::MAX, _dt);
    RUN_BINARY(dtype::Float32());
    RUN_BINARY(dtype::Float16());
    RUN_BINARY(dtype::BFloat16());
    RUN_BINARY(dtype::Int32());
    RUN_BINARY(dtype::Int16());
    //! true_div
    run_binary(4, 32, 10, 10, Mode::TRUE_DIV, dtype::Float32());
    RUN_BINARY(dtype::Float16());
    checker.set_epsilon(1e-2);
    run_binary(4, 32, 10, 10, Mode::TRUE_DIV, dtype::Float16());
    RUN_BINARY(dtype::BFloat16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_binary(4, 32, 10, 10, Mode::TRUE_DIV, dtype::BFloat16());
#undef RUN_BINARY

#define RUN_UNARY(_dt)                         \
    run_unary(4, 32, 10, 10, Mode::ABS, _dt);  \
    run_unary(4, 32, 10, 10, Mode::SIN, _dt);  \
    run_unary(4, 32, 10, 10, Mode::COS, _dt);  \
    run_unary(4, 32, 10, 10, Mode::EXP, _dt);  \
    run_unary(4, 32, 10, 10, Mode::CEIL, _dt); \
    run_unary(4, 32, 10, 10, Mode::TANH, _dt);
    RUN_UNARY(dtype::Float32());
    RUN_UNARY(dtype::BFloat16());
    checker.set_epsilon(1e-2);
    RUN_UNARY(dtype::Float16());
    //! FLOOR
    run_unary(4, 32, 10, 10, Mode::FLOOR, dtype::Float32());
    run_unary(4, 32, 10, 10, Mode::FLOOR, dtype::Float16());
    //! INT TEST
    run_unary(4, 32, 10, 10, Mode::ABS, dtype::Int16());
    run_unary(4, 32, 10, 10, Mode::ABS, dtype::Int32());
#undef RUN_UNARY

    //! naive impl
    run_binary(4, 32, 10, 10, Mode::LT, dtype::Float32());
    run_binary(4, 32, 10, 10, Mode::LT, dtype::Int32());
    run_binary(4, 32, 10, 10, Mode::LEQ, dtype::Float32());
    run_binary(4, 32, 10, 10, Mode::LEQ, dtype::Int32());
    run_binary(4, 32, 10, 10, Mode::EQ, dtype::Float32());
    run_binary(4, 32, 10, 10, Mode::EQ, dtype::Int32());

    auto rng = UniformFloatRNG(0.01, 2.0);
    checker.set_rng(0, &rng);
    run_unary(4, 32, 10, 10, Mode::LOG, dtype::Float32());
    run_unary(4, 32, 10, 10, Mode::LOG, dtype::BFloat16());
    checker.set_epsilon(1e-2);
    run_unary(4, 32, 10, 10, Mode::LOG, dtype::Float16());
    run_unary(4, 32, 10, 10, Mode::NEGATE, dtype::Float32());
    run_unary(4, 32, 10, 10, Mode::NEGATE, dtype::BFloat16());
    run_unary(4, 32, 10, 10, Mode::NEGATE, dtype::Float16());

    auto rng_int = UniformIntNonZeroRNG(1, 65535);
    checker.set_rng(0, &rng_int);
    run_unary(4, 32, 10, 10, Mode::NEGATE, dtype::Int32());
    run_unary(4, 32, 10, 10, Mode::NEGATE, dtype::Int16());
}

TEST_F(NAIVE, ELEMWISE_BROADCAST_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    //! do broadcast test
    auto run_binary_broadcast = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
                                    DType dtype) {
        checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype);
        checker.execs({{N, C, H, W}, {N, C, 1, 1}, {}});
        checker.execs({{N, C, 1, 1}, {N, C, H, W}, {}});
        checker.execs({{N, C, H, W}, {1}, {}});
        checker.execs({{1}, {N, C, H, W}, {}});
        checker.execs({{N, C, H, W}, {1, C, H, W}, {}});
        checker.execs({{1, C, H, W}, {N, C, H, W}, {}});
    };

#define RUN_BINARY(_dt)                                  \
    run_binary_broadcast(4, 32, 10, 10, Mode::ADD, _dt); \
    run_binary_broadcast(4, 32, 10, 10, Mode::SUB, _dt); \
    run_binary_broadcast(4, 32, 10, 10, Mode::MUL, _dt); \
    run_binary_broadcast(4, 32, 10, 10, Mode::MIN, _dt); \
    run_binary_broadcast(4, 32, 10, 10, Mode::MAX, _dt);
    RUN_BINARY(dtype::Float32());
    run_binary_broadcast(4, 32, 10, 10, Mode::TRUE_DIV, dtype::Float32());
    RUN_BINARY(dtype::Float16());
    checker.set_epsilon(1e-2);
    run_binary_broadcast(4, 32, 10, 10, Mode::TRUE_DIV, dtype::Float16());
    RUN_BINARY(dtype::BFloat16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_binary_broadcast(4, 32, 10, 10, Mode::TRUE_DIV, dtype::BFloat16());
    RUN_BINARY(dtype::Int16());
    RUN_BINARY(dtype::Int32());
#undef RUN_BINARY
}

TEST_F(NAIVE, ELEMWISE_FUSE_MUL_ADD3_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_mul_add = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
        checker.set_param(Mode::FUSE_MUL_ADD3)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype);
        checker.execs({{1}, {N, C, H, W}, {1}, {}});
        checker.execs({{N, C, 1, 1}, {N, C, H, W}, {1}, {}});
        checker.execs({{N, C, H, W}, {N, C, H, W}, {1}, {}});
        checker.execs({{N, C, 1, 1}, {N, C, H, W}, {N, C, 1, 1}, {}});
    };
    run_mul_add(4, 32, 10, 10, dtype::Float32());
    checker.set_epsilon(1e-2);
    run_mul_add(4, 32, 10, 10, dtype::Float16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_mul_add(4, 32, 10, 10, dtype::BFloat16());
    run_mul_add(4, 32, 10, 10, dtype::Int16());
    run_mul_add(4, 32, 10, 10, dtype::Int32());
}

TEST_F(NAIVE, ELEMWISE_FUSE_MUL_ADD4_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_mul_add = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
        checker.set_param(Mode::FUSE_MUL_ADD4)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_dtype(3, dtype)
                .set_dtype(4, dtype);
        checker.execs({{1}, {N, C, H, W}, {1}, {N, C, H, W}, {}});
        checker.execs({{1}, {N, C, H, W}, {N, C, H, W}, {1}, {}});
        checker.execs({{N, C, 1, 1}, {N, C, H, W}, {N, C, 1, 1}, {N, C, H, W}, {}});
        checker.execs({{N, C, H, W}, {N, C, H, W}, {N, C, H, W}, {N, C, H, W}, {}});
    };
    run_mul_add(4, 32, 10, 10, dtype::Float32());
    checker.set_epsilon(1e-2);
    run_mul_add(4, 32, 10, 10, dtype::Float16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_mul_add(4, 32, 10, 10, dtype::BFloat16());
    run_mul_add(4, 32, 10, 10, dtype::Int16());
    run_mul_add(4, 32, 10, 10, dtype::Int32());
}

TEST_F(NAIVE, ELEMWISE_FUSE_ADD_RELU_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_mul_add = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
        checker.set_param(Mode::FUSE_ADD_RELU)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype);
        checker.execs({{N, C, H, W}, {N, C, H, W}, {}});
    };
    run_mul_add(4, 32, 10, 10, dtype::Float32());
    checker.set_epsilon(1e-2);
    run_mul_add(4, 32, 10, 10, dtype::Float16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_mul_add(4, 32, 10, 10, dtype::BFloat16());
}

TEST_F(NAIVE, ELEMWISE_FUSE_ADD_SIGMOID_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_mul_add = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
        checker.set_param(Mode::FUSE_ADD_SIGMOID)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype);
        checker.execs({{N, C, H, W}, {N, C, H, W}, {}});
    };
    run_mul_add(4, 32, 10, 10, dtype::Float32());
    checker.set_epsilon(1e-2);
    run_mul_add(4, 32, 10, 10, dtype::Float16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_mul_add(4, 32, 10, 10, dtype::BFloat16());
}

TEST_F(NAIVE, ELEMWISE_FUSE_ADD_TANH_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_mul_add = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
        checker.set_param(Mode::FUSE_ADD_TANH)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype);
        checker.execs({{N, C, H, W}, {N, C, H, W}, {}});
    };
    run_mul_add(4, 32, 10, 10, dtype::Float32());
    checker.set_epsilon(1e-2);
    run_mul_add(4, 32, 10, 10, dtype::Float16());
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_mul_add(4, 32, 10, 10, dtype::BFloat16());
}

TEST_F(NAIVE, ELEMWISE_VECTOR_RECORD) {
    TaskRecordChecker<ElemwiseForward> checker(2);
    using Mode = ElemwiseForward::Param::Mode;
    auto run_vector = [&](size_t N, DType dtype, Mode mode) {
        checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(
                2, dtype);
        checker.execs({{N}, {1, N}, {}});
        checker.execs({{1, N}, {N}, {}});
        checker.execs({{N}, {1}, {}});
        checker.execs({{1}, {N}, {}});
        checker.execs({{1}, {1, 1}, {}});
        checker.execs({{1, 1, 1}, {1}, {}});
    };
    run_vector(1000, dtype::Float32(), Mode::ADD);
    run_vector(1000, dtype::Float32(), Mode::MUL);
    checker.set_epsilon(1e-2);
    run_vector(1000, dtype::Float16(), Mode::ADD);
    run_vector(1000, dtype::Float16(), Mode::MUL);
    //! FIXME: precision is especially low
    checker.set_epsilon(1e-1);
    run_vector(1000, dtype::BFloat16(), Mode::ADD);
    run_vector(1000, dtype::BFloat16(), Mode::MUL);
}

//! EYE
TEST_F(NAIVE, EYE_RECORD) {
    TaskRecordChecker<Eye> checker(2);
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
        for (int k = -20; k < 20; ++k) {
            checker.set_param({k, dtype.enumv()});
            checker.set_dtype(0, dtype);
            checker.execs(TensorShapeArray{{3, 4}});
            checker.execs(TensorShapeArray{{4, 3}});
        }
}

//! FILL
TEST_F(NAIVE, FILL_RECORD) {
    TaskRecordChecker<Fill> checker(2);
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
        for (float value : {-1.23, 0.0, 0.001, 234.0, 2021.072}) {
            checker.set_param({value});
            checker.set_dtype(0, dtype);
            checker.exec(TensorShapeArray{{1, 1}});
            checker.exec(TensorShapeArray{{2, 3, 4}});
        }
}

//! LINSPACE
TEST_F(NAIVE, LINSPACE_RECORD) {
    TaskRecordChecker<Linspace> checker(2);
    Linspace::Param param;
    param.start = 0.5;
    param.stop = 1.5;
    param.endpoint = true;
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
        checker.set_dtype(0, dtype).set_param(param).exec(TensorShapeArray{{11}});
    }
    param.endpoint = false;
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
        checker.set_dtype(0, dtype).set_param(param).exec(TensorShapeArray{{11}});
    }
}

//! LOCAL
TEST_F(NAIVE, LOCAL_FORWARD_RECORD) {
    auto args = local::get_args_for_cuda();
    for (size_t i = 0; i < 2; ++i) {
        auto&& arg = args[i];
        TaskRecordChecker<LocalForward> checker(2);
        checker.set_param(arg.param).exec(
                TensorShapeArray{arg.sshape(), arg.fshape(), arg.dshape()});
    }
}

TEST_F(NAIVE, LOCAL_BACKWARD_DATA_RECORD) {
    using namespace local;
    auto args = local::get_args_bwd_data_for_cuda();
    for (size_t i = 0; i < 2; ++i) {
        auto&& arg = args[i];
        TaskRecordChecker<LocalBackwardData> checker(2);
        checker.set_param(arg.param).exec(
                TensorShapeArray{arg.fshape(), arg.dshape(), arg.sshape()});
    }
}

TEST_F(NAIVE, LOCAL_BACKWARD_FILTER_RECORD) {
    using namespace local;
    auto args = local::get_args_bwd_filter_for_cuda();
    for (size_t i = 0; i < 2; ++i) {
        auto&& arg = args[i];
        TaskRecordChecker<LocalBackwardFilter> checker(2);
        checker.set_param(arg.param).exec(
                TensorShapeArray{arg.sshape(), arg.dshape(), arg.fshape()});
    }
}

//! matrix inverse
TEST_F(NAIVE, MATRIX_INVERSE_RECORD) {
    TaskRecordChecker<MatrixInverse> checker(2);
    checker.exec({{10, 20, 20}, {}});
}

//! matmul
TEST_F(NAIVE, MATRIX_MUL_RECORD) {
    TaskRecordChecker<MatrixMul> checker(2);
    MatrixMul::Param param;
    param.transposeA = false;
    param.transposeB = false;
    checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, (uint8_t)128))
            .set_dtype(1, dtype::Quantized8Asymm(0.2f, (uint8_t)233))
            .set_dtype(2, dtype::QuantizedS32(0.1f * 0.2f));
    checker.set_param(param).exec({{4, 7}, {7, 5}, {}});
    param.transposeA = true;
    checker.set_dtype(0, dtype::Quantized8Asymm(0.7f, (uint8_t)128))
            .set_dtype(1, dtype::Quantized8Asymm(0.4f, (uint8_t)128))
            .set_dtype(2, dtype::QuantizedS32(0.7f * 0.4f));
    checker.set_param(param).exec({{2, 1}, {2, 1}, {}});
}

//! pooling
TEST_F(NAIVE, POOLING_QUANTIZED_RECORD) {
    using Mode = Pooling::Param::Mode;
    TaskRecordChecker<Pooling> checker(2);
    Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
    auto dt = dtype::Quantized8Asymm(0.1f, (uint8_t)128);
    checker.set_dtype(0, dt).set_dtype(1, dt);
    checker.set_param(param).exec({{1, 1, 3, 3}, {}});
    param = {Mode::AVERAGE, 1, 1, 2, 2, 2, 2};
    checker.set_param(param).exec({{1, 1, 3, 3}, {}});
    param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2, 2, 2};
    checker.set_param(param).exec({{1, 1, 3, 3}, {}});
    auto dt32 = dtype::QuantizedS32(0.233f);
    checker.set_dtype(0, dt32).set_dtype(1, dt32);
    param = {Mode::MAX, 1, 1, 2, 2, 2, 2};
    checker.set_param(param).exec({{1, 1, 3, 3}, {}});
}

TEST_F(NAIVE, REDUCE_QUANTIZED_RECORD) {
    using Mode = Reduce::Param::Mode;
    TaskRecordChecker<Reduce> checker(2);
    Reduce::Param param;
    param.mode = Mode::SUM;
    param.data_type = param::Reduce::DataType::QUINT_I8xO32;
    param.axis = 0;
    checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, (uint8_t)128))
            .set_dtype(1, dtype::QuantizedS32(0.1f));
    checker.set_param(param).exec({{3, 4}, {}});
    param.data_type = param::Reduce::DataType::DEFAULT;
    param.mode = Mode::MEAN;
    checker.set_dtype(0, dtype::Quantized8Asymm(1.f, (uint8_t)128))
            .set_dtype(1, dtype::Quantized8Asymm(1.f, (uint8_t)128));
    checker.set_param(param).exec({{3, 4}, {}});
    checker.set_dtype(0, dtype::Quantized8Asymm(0.00233f, (uint8_t)128))
            .set_dtype(1, dtype::Quantized8Asymm(0.00233f, (uint8_t)128));
    checker.set_param(param).exec({{3, 4}, {}});
    checker.set_dtype(0, dtype::Quantized8Asymm(7e-10f, (uint8_t)45))
            .set_dtype(1, dtype::Quantized8Asymm(7e-10f, (uint8_t)45));
    checker.set_param(param).exec({{3, 4}, {}});
}

//! relayout format
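// Conversions between NCHW and the packed NCHW4 / NCHW88 layouts, including
// channel counts that require padding and grouped weight layouts.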
TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW4_NCHW_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW4_NCHW};
    checker.set_param(param).exec({{1, 2, 1, 2, 4}, {}});
    param.oc = 7;
    checker.set_param(param).exec({{1, 2, 1, 2, 4}, {}});
    param.oc = 6;
    param.group = 2;
    checker.set_param(param).exec({{1, 2, 1, 2, 4}, {}});
}

TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW_NCHW4_WEIGHT_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT};
    checker.set_param(param);
    checker.exec({{2, 2, 2, 2}, {}});
    checker.exec({{2, 2, 1, 2, 2}, {}});
}

TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW_NCHW4_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW_NCHW4};
    checker.set_param(param).exec({{1, 8, 1, 2}, {}});
    param.group = 4;
    checker.set_param(param).exec({{1, 8, 1, 2}, {}});
    param.group = 2;
    checker.set_param(param).exec({{1, 6, 1, 2}, {}});
}

TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW88_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    {
        RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW_NCHW88};
        checker.set_param(param);
        checker.exec({{1, 8, 1, 2}, {}});
        checker.exec({{2, 8, 1, 2}, {}});
        checker.exec({{2, 4, 1, 2}, {}});
        checker.exec({{1, 3, 64, 64}, {}});
    }
    {
        RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW88_NCHW};
        checker.set_param(param).exec({{1, 1, 1, 2, 8}, {}});
    }
}

TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW88_DENSE_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    RelayoutFormat::Param param{
            RelayoutFormat::Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT};
    checker.set_param(param);
    checker.exec({{8, 8, 1, 1}, {}});
    checker.exec({{8, 2, 1, 1}, {}});
}

TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW88_CHAIN_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    RelayoutFormat::Param param{
            RelayoutFormat::Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT};
    checker.set_param(param);
    checker.exec({{8, 1, 1, 1, 2}, {}});
    checker.exec({{2, 1, 1, 1, 2}, {}});
}

TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW88_GROUP_RECORD) {
    TaskRecordChecker<RelayoutFormat> checker(2);
    {
        RelayoutFormat::Param param{
                RelayoutFormat::Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT};
        checker.set_param(param);
        checker.exec({{1, 8, 8, 1, 1}, {}});
        checker.exec({{1, 8, 2, 1, 1}, {}});
    }
    {
        RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW88_NCHW};
        checker.set_param(param).exec({TensorShape{1, 8, 64, 64, 8}, {}});
    }
}

//! separable conv
TEST_F(NAIVE, SEPARABLE_CONV_RECORD) {
    using TestArg = megdnn::test::separable_conv::TestArg;
    std::vector<TestArg> args = separable_conv::get_args();
    TaskRecordChecker<SeparableConvForward> checker(2);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter_x, arg.filter_y, {}});
    }
}

//! warp affine
TEST_F(NAIVE, WARP_AFFINE_RECORD) {
    TaskRecordChecker<WarpAffine> checker(2);
    WarpAffine::Param param;
    param.border_mode = WarpAffine::Param::BorderMode::BORDER_REFLECT;
    param.imode = WarpAffine::Param::InterpolationMode::LINEAR;
    param.format = WarpAffine::Param::Format::NCHW;
    checker.set_dtype(0, dtype::Uint8{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Uint8{});
    checker.set_param(param).exec({{1, 1, 3, 3}, {1, 2, 3}, {1, 1, 2, 2}});
    checker.set_dtype(0, dtype::Quantized8Asymm{1.4f, static_cast<uint8_t>(127)})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Quantized8Asymm{1.4f, static_cast<uint8_t>(127)});
    checker.set_param(param).exec({{1, 1, 3, 3}, {1, 2, 3}, {1, 1, 2, 2}});
}

TEST_F(NAIVE, WARP_AFFINE_CV_RECORD) {
    using TestArg = warp_affine::TestArg;
    std::vector<TestArg> args = warp_affine::get_cv_args();
    TaskRecordChecker<WarpAffine> checker(2);
    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_dtype(0, dtype::Uint8())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Uint8())
                .execs({arg.src, arg.trans, arg.dst});
    }
    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .execs({arg.src, arg.trans, arg.dst});
    }
}

//! warp perspective
TEST_F(NAIVE, WARP_PERSPECTIVE_RECORD) {
    TaskRecordChecker<WarpPerspective> checker(2);
    WarpPerspective::Param param;
    param.bmode = WarpPerspective::Param::BorderMode::BORDER_REFLECT;
    param.imode = WarpPerspective::Param::InterpolationMode::LINEAR;
    param.format = WarpPerspective::Param::Format::NCHW;
    checker.set_dtype(0, dtype::Uint8{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Uint8{});
    checker.set_param(param).exec({{1, 1, 3, 3}, {1, 3, 3}, {1, 1, 2, 2}});
    checker.set_dtype(0, dtype::Quantized8Asymm{1.4f, static_cast<uint8_t>(127)})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Quantized8Asymm{1.4f, static_cast<uint8_t>(127)});
    checker.set_param(param).exec({{1, 1, 3, 3}, {1, 3, 3}, {1, 1, 2, 2}});
}
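
// Quantized (QuantizedS8) NCHW4 warp perspective, swept over the WRAP,
// REFLECT, REPLICATE and CONSTANT border modes.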
TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW4_RECORD) {
    using Param = WarpPerspective::Param;
    WarpPerspective::Param param;
    TaskRecordChecker<WarpPerspectiveForward> checker(2);
    WarpPerspectiveMatRNG rng;
    checker.set_rng(1, &rng);
    checker.set_dtype(0, dtype::QuantizedS8(0.1f));
    checker.set_dtype(2, dtype::QuantizedS8(0.1f));
    for (auto bmode :
         {WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
          WarpPerspective::BorderMode::REPLICATE,
          WarpPerspective::BorderMode::CONSTANT}) {
        param.border_val = 0.3f;
        param.bmode = bmode;
        param.imode = Param::InterpolationMode::LINEAR;
        param.format = Param::Format::NCHW4;
        checker.set_param(param);
        checker.execs({{2, 1, 10, 11, 4}, {2, 3, 3}, {2, 1, 11, 12, 4}});
        checker.execs({{1, 25, 25, 25, 4}, {1, 3, 3}, {1, 25, 25, 510, 4}});
        checker.execs({{1, 25, 25, 25, 4}, {1, 3, 3}, {1, 25, 51, 51, 4}});
        checker.execs({{1, 25, 51, 51, 4}, {1, 3, 3}, {1, 25, 25, 25, 4}});
    }
}

TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_RECORD) {
    TaskRecordChecker<WarpPerspective> checker(2);
    WarpPerspective::Param param;
    param.bmode = WarpPerspective::Param::BorderMode::BORDER_REFLECT;
    param.imode = WarpPerspective::Param::InterpolationMode::LINEAR;
    param.format = WarpPerspective::Param::Format::NCHW;
    checker.set_dtype(0, dtype::Uint8{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Uint8{});
    checker.set_param(param).exec({{1, 1, 3, 3}, {1, 3, 3}, {1, 1, 2, 2}});
    checker.set_dtype(0, dtype::Quantized8Asymm{1.4f, static_cast<uint8_t>(127)})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Quantized8Asymm{1.4f, static_cast<uint8_t>(127)});
    checker.set_param(param).exec({{1, 1, 3, 3}, {1, 3, 3}, {1, 1, 2, 2}});
}

TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_NCHW4_RECORD) {
    using Param = WarpPerspective::Param;
    WarpPerspective::Param param;
    TaskRecordChecker<WarpPerspectiveForward> checker(2);
    WarpPerspectiveMatRNG rng;
    checker.set_rng(1, &rng);
    checker.set_dtype(0, dtype::QuantizedS8(0.1f));
    checker.set_dtype(2, dtype::QuantizedS8(0.1f));
    for (auto bmode :
         {WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
          WarpPerspective::BorderMode::REPLICATE,
          WarpPerspective::BorderMode::CONSTANT}) {
        param.border_val = 0.3f;
        param.bmode = bmode;
        param.imode = Param::InterpolationMode::LINEAR;
        param.format = Param::Format::NCHW4;
        checker.set_param(param);
        checker.execs({{2, 1, 10, 11, 4}, {2, 3, 3}, {2, 1, 11, 12, 4}});
        checker.execs({{1, 25, 25, 25, 4}, {1, 3, 3}, {1, 25, 25, 510, 4}});
        checker.execs({{1, 25, 25, 25, 4}, {1, 3, 3}, {1, 25, 51, 51, 4}});
        checker.execs({{1, 25, 51, 51, 4}, {1, 3, 3}, {1, 25, 25, 25, 4}});
    }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}