You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

nn_int.cpp 28 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
#include "megbrain/opr/nn_int.h"

#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"

#include <cfloat>
  7. using namespace mgb;
namespace {
using Checker31 = AutoOprChecker<3, 1>;

//! Build an AutoOprChecker for a 3-input / 1-output ElemwiseMultiType mode.
//!
//! \param mode    the ElemwiseMultiType mode under test
//! \param dtypes  per-input dtypes; each checker input is TypeCvt'd to the
//!                corresponding dtype before entering the opr
//!
//! The graph under test converts the three float inputs to \p dtypes, runs
//! ElemwiseMultiType, and converts the result back to Float32 so it can be
//! compared numerically. The reference `fwd` reproduces the same pipeline
//! with raw megdnn operators on the naive handle.
std::unique_ptr<Checker31> make_elemwise_multi_type_checker3(
        opr::ElemwiseMultiType::Mode mode, const std::array<DType, 3>& dtypes) {
    using Checker = Checker31;
    auto make_graph = [=](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        // cast input i to its target dtype before feeding the opr
        auto as_type = [&dtypes, &inputs](size_t i) {
            return opr::TypeCvt::make(inputs[i], dtypes[i]);
        };
        auto ovar = opr::ElemwiseMultiType::make(
                {as_type(0), as_type(1), as_type(2)}, mode);
        // read back as Float32 so the checker can diff against `fwd`
        return {opr::TypeCvt::make(ovar, dtype::Float32{})};
    };
    auto fwd = [=](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto opr = megdnn_naive_handle()->create_operator<megdnn::ElemwiseMultiType>();
        auto opr_typecvt = megdnn_naive_handle()->create_operator<megdnn::TypeCvt>();
        opr->param() = {mode};
        megdnn::TensorShapeArray inp_shapes(3);
        megdnn::TensorNDArray inp_tensors(3);
        HostTensorND cvt_val[3];
        for (int i = 0; i < 3; ++i) {
            // convert each host input to its target dtype via megdnn TypeCvt
            cvt_val[i]
                    .dtype(dtypes[i])
                    .comp_node(inp[i]->comp_node())
                    .resize(inp[i]->shape());
            opr_typecvt->exec(inp[i]->as_megdnn(), cvt_val[i].as_megdnn());
            inp_shapes[i] = inp[i]->shape();
            inp_tensors[i] = cvt_val[i].as_megdnn();
        }
        // broadcast the three input shapes into the output shape
        TensorShape out_shape;
        megdnn::Elemwise::deduce_shape(inp_shapes, out_shape);
        // ask the mode trait for the output dtype of this mode
        auto trait = megdnn::ElemwiseMultiType::ModeTrait::from_mode(mode);
        DType dtype;
        trait.check_out(dtype, false);
        HostTensorND tmp_out{inp[0]->comp_node(), out_shape, dtype};
        opr->exec(inp_tensors, tmp_out.as_megdnn());
        // final TypeCvt back to the checker's (float) dest tensor
        dest[0].resize(out_shape);
        opr_typecvt->exec(tmp_out.as_megdnn(), dest[0].as_megdnn());
    };
    return std::make_unique<Checker>(make_graph, fwd);
}
} // anonymous namespace
  50. TEST(TestOprElemwiseMultiType, Fma3Int16x32x32x32) {
  51. make_elemwise_multi_type_checker3(
  52. opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32,
  53. {dtype::Int16{}, dtype::Int32{}, dtype::Int32{}})
  54. ->disable_grad_check()
  55. .run({TensorShape{3, 4, 5}, {1, 4, 1}, {1, 4, 1}})
  56. .run({TensorShape{1, 4, 5}, {1, 4, 1}, {1, 4, 1}})
  57. .run({TensorShape{3, 4, 5}, {3, 4, 1}, {3, 4, 1}});
  58. }
  59. TEST(TestOprElemwiseMultiType, Fma3IXxf32xf32xi8) {
  60. std::array<DType, 3> src_types{dtype::Int8{}, dtype::Int16{}, dtype::Int32{}};
  61. for (auto src_type : src_types) {
  62. make_elemwise_multi_type_checker3(
  63. opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8,
  64. {src_type, dtype::Float32{}, dtype::Float32{}})
  65. ->disable_grad_check()
  66. .run({TensorShape{3, 4}, {3, 4}, {3, 4}})
  67. .run({TensorShape{3, 4}, {1, 4}, {1, 4}})
  68. .run({TensorShape{9, 4, 8}, {1, 4, 8}, {1, 4, 8}});
  69. }
  70. }
//! Binary quantized modes with QuantizedS8 inputs and a QuantizedS32 output:
//! the ElemwiseMultiType result is compared against a float reference built
//! from plain Elemwise / TypeCvt operators.
TEST(TestOprElemwiseMultiType, QuantizedModeBinary_IS8_OS32) {
    using Checker = AutoOprChecker<2, 1>;
    DType x_dtype = dtype::QuantizedS8(0.15f);
    DType y_dtype = dtype::QuantizedS8(0.20f);
    DType z_dtype = dtype::QuantizedS32(0.15f);
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode : {Mode::QFUSE_ADD_RELU, Mode::QADD, Mode::QMUL}) {
        // Graph under test: quantize both operands, run ElemwiseMultiType on
        // a cpux comp node, then cvt back to Float32 for comparison.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{z_dtype};  // pins the output dtype
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto b = opr::Copy::make(inputs[1], cpu);
            auto y = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype), opr::TypeCvt::make(b, y_dtype)},
                    {mode}, config);
            y = opr::TypeCvt::make(y, dtype::Float32());
            return {y};
        };
        // Reference: quantize the inputs, compute in Float32 with ordinary
        // arithmetic, re-quantize to z_dtype, then read back as Float32.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            auto y = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[1]), y_dtype);
            SymbolVar z;
            if (mode == Mode::QMUL) {
                z = opr::TypeCvt::make(x, dtype::Float32()) *
                    opr::TypeCvt::make(y, dtype::Float32());
                z = opr::TypeCvt::make(z, z_dtype);
            }
            if (mode == Mode::QADD) {
                z = opr::TypeCvt::make(x, dtype::Float32()) +
                    opr::TypeCvt::make(y, dtype::Float32());
                z = opr::TypeCvt::make(z, z_dtype);
            }
            if (mode == Mode::QFUSE_ADD_RELU) {
                z = opr::TypeCvt::make(x, dtype::Float32()) +
                    opr::TypeCvt::make(y, dtype::Float32());
                z = opr::Elemwise::make({z}, {opr::Elemwise::Mode::RELU});
                z = opr::TypeCvt::make(z, z_dtype);
            }
            z = opr::TypeCvt::make(z, dtype::Float32());
            auto func = cg->compile({make_callback_copy(z, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}, {3, 4}})
                .run({TensorShape{3, 4}, {1, 4}})
                .run({TensorShape{9, 4, 8}, {1, 4, 8}}, options);
    }
}
//! Input generator producing non-negative values, uniform over [0, FLT_MAX);
//! used below for modes whose reference divides by the input or takes its
//! log (QTRUE_DIV, QMOD, QFLOOR_DIV, QLOG, QLOG1P).
// NOTE(review): the name is a typo for "gen_positive"; kept as-is because it
// is referenced throughout this file.
auto gen_postive = [](HostTensorND& dest) {
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> mask_generator{
            0.f, FLT_MAX};
    dest = *mask_generator(dest.shape(), dest.comp_node());
};
  132. //! \warning: asin and acos has lower precision,
  133. //! they may produce nan.
  134. auto gen_asin_acos = [](HostTensorND& dest) {
  135. HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> mask_generator{
  136. -0.5f, 0.5f};
  137. dest = *mask_generator(dest.shape(), dest.comp_node());
  138. };
  139. //! \warning: erfinv and erfcinv has lower precision,
  140. //! should give them more strict input.
  141. auto gen_erfinv = [](HostTensorND& dest) {
  142. HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> mask_generator{
  143. -0.5f, 0.5f};
  144. dest = *mask_generator(dest.shape(), dest.comp_node());
  145. };
  146. auto gen_erfcinv = [](HostTensorND& dest) {
  147. HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> mask_generator{
  148. 0.5f, 1.5f};
  149. dest = *mask_generator(dest.shape(), dest.comp_node());
  150. };
//! Expand one switch case mapping a quantized unary mode Q<MODE> to the
//! Float32 reference Elemwise mode <MODE>, applied to `xf` (the float-cast
//! input in scope at the expansion site); result is stored in `d`.
#define MAKE_UNARY(_MODE) \
    case Mode::Q##_MODE: \
        d = opr::Elemwise::make({xf}, {opr::Elemwise::Mode::_MODE}); \
        break
//! All quantized unary modes with QuantizedS8 input and QuantizedS8 output,
//! checked against a Float32 Elemwise reference (via MAKE_UNARY).
TEST(TestOprElemwiseMultiType, QuantizedModeUnary_IS8_OS8) {
    using Checker = AutoOprChecker<1, 1>;
    DType x_dtype = dtype::QuantizedS8(1.15f);
    DType d_dtype = dtype::QuantizedS8(2.00f);
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode :
         {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH,
          Mode::QNEGATE, Mode::QACOS, Mode::QASIN, Mode::QCEIL, Mode::QCOS,
          Mode::QEXPM1, Mode::QFLOOR, Mode::QLOG, Mode::QLOG1P, Mode::QSIN,
          Mode::QROUND, Mode::QERF, Mode::QERFINV, Mode::QERFC, Mode::QERFCINV,
          Mode::QFAST_TANH, Mode::QH_SWISH}) {
        // Graph under test: single-input ElemwiseMultiType with the output
        // dtype pinned to d_dtype, read back as Float32.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{d_dtype};
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto d = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype)}, {mode}, config);
            d = opr::TypeCvt::make(d, dtype::Float32());
            return {d};
        };
        // Float reference: quantize to x_dtype, compute the matching plain
        // Elemwise mode in Float32, then round-trip through d_dtype.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            SymbolVar d;
            auto xf = opr::TypeCvt::make(x, dtype::Float32());
            // one case per mode in the loop list above (22 modes)
            switch (mode) {
                MAKE_UNARY(RELU);
                MAKE_UNARY(ABS);
                MAKE_UNARY(SIGMOID);
                MAKE_UNARY(EXP);
                MAKE_UNARY(TANH);
                MAKE_UNARY(FAST_TANH);
                MAKE_UNARY(NEGATE);
                MAKE_UNARY(ACOS);
                MAKE_UNARY(ASIN);
                MAKE_UNARY(CEIL);
                MAKE_UNARY(COS);
                MAKE_UNARY(EXPM1);
                MAKE_UNARY(FLOOR);
                MAKE_UNARY(LOG);
                MAKE_UNARY(LOG1P);
                MAKE_UNARY(SIN);
                MAKE_UNARY(ROUND);
                MAKE_UNARY(ERF);
                MAKE_UNARY(ERFINV);
                MAKE_UNARY(ERFC);
                MAKE_UNARY(ERFCINV);
                MAKE_UNARY(H_SWISH);
                default:
                    mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
                    break;
            }
            d = opr::TypeCvt::make(d, d_dtype);
            d = opr::TypeCvt::make(d, dtype::Float32());
            auto func = cg->compile({make_callback_copy(d, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        // Modes with restricted domains or low precision use tailored input
        // generators (defined above).
        switch (mode) {
            case Mode::QACOS:
            case Mode::QASIN:
                checker.set_input_generator(0, gen_asin_acos);
                break;
            case Mode::QLOG:
            case Mode::QLOG1P:
                checker.set_input_generator(0, gen_postive);
                break;
            case Mode::QERFINV:
                checker.set_input_generator(0, gen_erfinv);
                break;
            case Mode::QERFCINV:
                checker.set_input_generator(0, gen_erfcinv);
                break;
            default:
                break;
        }
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}})
                .run({TensorShape{4, 8}})
                .run({TensorShape{9, 4, 8}}, options);
    }
}
//! Quantized unary modes with asymmetric Quantized8Asymm input and output
//! (zero point 128); same structure as the QuantizedS8 test above, but the
//! mode list here omits QH_SWISH.
TEST(TestOprElemwiseMultiType, QuantizedModeUnary_I8Asymm_O8Asymm) {
    using Checker = AutoOprChecker<1, 1>;
    DType x_dtype = dtype::Quantized8Asymm(1.15f, static_cast<uint8_t>(128));
    DType d_dtype = dtype::Quantized8Asymm(2.00f, static_cast<uint8_t>(128));
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode :
         {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH,
          Mode::QNEGATE, Mode::QACOS, Mode::QASIN, Mode::QCEIL, Mode::QCOS,
          Mode::QEXPM1, Mode::QFLOOR, Mode::QLOG, Mode::QLOG1P, Mode::QSIN,
          Mode::QROUND, Mode::QERF, Mode::QERFINV, Mode::QERFC, Mode::QERFCINV,
          Mode::QFAST_TANH}) {
        // Graph under test: ElemwiseMultiType on quantized input, output
        // dtype pinned to d_dtype, read back as Float32.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{d_dtype};
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto d = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype)}, {mode}, config);
            d = opr::TypeCvt::make(d, dtype::Float32());
            return {d};
        };
        // Float reference: matching plain Elemwise mode in Float32, then a
        // round trip through d_dtype.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            SymbolVar d;
            auto xf = opr::TypeCvt::make(x, dtype::Float32());
            switch (mode) {
                MAKE_UNARY(RELU);
                MAKE_UNARY(ABS);
                MAKE_UNARY(SIGMOID);
                MAKE_UNARY(EXP);
                MAKE_UNARY(TANH);
                MAKE_UNARY(FAST_TANH);
                MAKE_UNARY(NEGATE);
                MAKE_UNARY(ACOS);
                MAKE_UNARY(ASIN);
                MAKE_UNARY(CEIL);
                MAKE_UNARY(COS);
                MAKE_UNARY(EXPM1);
                MAKE_UNARY(FLOOR);
                MAKE_UNARY(LOG);
                MAKE_UNARY(LOG1P);
                MAKE_UNARY(SIN);
                MAKE_UNARY(ROUND);
                MAKE_UNARY(ERF);
                MAKE_UNARY(ERFINV);
                MAKE_UNARY(ERFC);
                MAKE_UNARY(ERFCINV);
                default:
                    mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
                    break;
            }
            d = opr::TypeCvt::make(d, d_dtype);
            d = opr::TypeCvt::make(d, dtype::Float32());
            auto func = cg->compile({make_callback_copy(d, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        // Domain-restricted / low-precision modes use tailored generators.
        switch (mode) {
            case Mode::QACOS:
            case Mode::QASIN:
                checker.set_input_generator(0, gen_asin_acos);
                break;
            case Mode::QLOG:
            case Mode::QLOG1P:
                checker.set_input_generator(0, gen_postive);
                break;
            case Mode::QERFINV:
                checker.set_input_generator(0, gen_erfinv);
                break;
            case Mode::QERFCINV:
                checker.set_input_generator(0, gen_erfcinv);
                break;
            default:
                break;
        }
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}})
                .run({TensorShape{4, 8}})
                .run({TensorShape{9, 4, 8}}, options);
    }
}
  328. #undef MAKE_UANRY
  329. #define MAKE_BINARY(_MODE) \
  330. case Mode::Q##_MODE: \
  331. d = opr::Elemwise::make({xf, yf}, {opr::Elemwise::Mode::_MODE}); \
  332. break
//! All quantized binary modes with QuantizedS8 inputs and output, checked
//! against a Float32 Elemwise reference (via MAKE_BINARY).
TEST(TestOprElemwiseMultiType, QuantizedModeBinary_IS8_OS8) {
    using Checker = AutoOprChecker<2, 1>;
    DType x_dtype = dtype::QuantizedS8(1.15f);
    DType y_dtype = dtype::QuantizedS8(2.0f);
    DType d_dtype = dtype::QuantizedS8(1.15f);
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode :
         {Mode::QFUSE_ADD_RELU,
          Mode::QADD,
          Mode::QMUL,
          Mode::QMIN,
          Mode::QMAX,
          Mode::QSUB,
          Mode::QTRUE_DIV,
          Mode::QFUSE_ADD_SIGMOID,
          Mode::QFUSE_ADD_TANH,
          Mode::QABS_GRAD,
          Mode::QFLOOR_DIV,
          Mode::QMOD,
          Mode::QSIGMOID_GRAD,
          Mode::QSWITCH_GT0,
          Mode::QTANH_GRAD,
          Mode::QLT,
          Mode::QLEQ,
          Mode::QEQ,
          Mode::QPOW,
          Mode::QLOG_SUM_EXP,
          Mode::QFAST_TANH_GRAD,
          Mode::QATAN2}) {
        // Graph under test: quantize both operands, run ElemwiseMultiType,
        // output dtype pinned to d_dtype, read back as Float32.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{d_dtype};
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto b = opr::Copy::make(inputs[1], cpu);
            auto d = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype), opr::TypeCvt::make(b, y_dtype)},
                    {mode}, config);
            d = opr::TypeCvt::make(d, dtype::Float32());
            return {d};
        };
        // Float reference: matching plain Elemwise mode in Float32, then a
        // round trip through d_dtype.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            auto y = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[1]), y_dtype);
            SymbolVar d;
            auto xf = opr::TypeCvt::make(x, dtype::Float32());
            auto yf = opr::TypeCvt::make(y, dtype::Float32());
            // one case per mode in the loop list above (22 modes)
            switch (mode) {
                MAKE_BINARY(FUSE_ADD_RELU);
                MAKE_BINARY(ADD);
                MAKE_BINARY(MUL);
                MAKE_BINARY(MIN);
                MAKE_BINARY(MAX);
                MAKE_BINARY(SUB);
                MAKE_BINARY(TRUE_DIV);
                MAKE_BINARY(FUSE_ADD_SIGMOID);
                MAKE_BINARY(FUSE_ADD_TANH);
                MAKE_BINARY(ABS_GRAD);
                MAKE_BINARY(FLOOR_DIV);
                MAKE_BINARY(MOD);
                MAKE_BINARY(SIGMOID_GRAD);
                MAKE_BINARY(SWITCH_GT0);
                MAKE_BINARY(TANH_GRAD);
                MAKE_BINARY(LT);
                MAKE_BINARY(LEQ);
                MAKE_BINARY(EQ);
                MAKE_BINARY(POW);
                MAKE_BINARY(LOG_SUM_EXP);
                MAKE_BINARY(FAST_TANH_GRAD);
                MAKE_BINARY(ATAN2);
                default:
                    mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
                    break;
            }
            d = opr::TypeCvt::make(d, d_dtype);
            d = opr::TypeCvt::make(d, dtype::Float32());
            auto func = cg->compile({make_callback_copy(d, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        // Division-like modes need a positive second operand to avoid
        // dividing by (values that quantize to) zero.
        switch (mode) {
            case Mode::QTRUE_DIV:
            case Mode::QMOD:
            case Mode::QFLOOR_DIV:
                checker.set_input_generator(1, gen_postive);
                break;
            default:
                break;
        }
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}, {3, 4}})
                .run({TensorShape{4, 8}, {1, 1}})
                .run({TensorShape{9, 4, 8}, {9, 4, 8}}, options);
    }
}
//! Quantized binary modes with asymmetric Quantized8Asymm inputs and output
//! (zero point 128); same structure as the QuantizedS8 test above, with
//! QFUSE_ADD_H_SWISH additionally covered here.
TEST(TestOprElemwiseMultiType, QuantizedModeBinary_I8Asymm_O8Asymm) {
    using Checker = AutoOprChecker<2, 1>;
    DType x_dtype = dtype::Quantized8Asymm(1.15f, static_cast<uint8_t>(128));
    DType y_dtype = dtype::Quantized8Asymm(2.0f, static_cast<uint8_t>(128));
    DType d_dtype = dtype::Quantized8Asymm(1.15f, static_cast<uint8_t>(128));
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode :
         {Mode::QFUSE_ADD_RELU,
          Mode::QADD,
          Mode::QMUL,
          Mode::QMIN,
          Mode::QMAX,
          Mode::QSUB,
          Mode::QTRUE_DIV,
          Mode::QFUSE_ADD_SIGMOID,
          Mode::QFUSE_ADD_TANH,
          Mode::QFUSE_ADD_H_SWISH,
          Mode::QABS_GRAD,
          Mode::QFLOOR_DIV,
          Mode::QMOD,
          Mode::QSIGMOID_GRAD,
          Mode::QSWITCH_GT0,
          Mode::QTANH_GRAD,
          Mode::QLT,
          Mode::QLEQ,
          Mode::QEQ,
          Mode::QPOW,
          Mode::QLOG_SUM_EXP,
          Mode::QFAST_TANH_GRAD,
          Mode::QATAN2}) {
        // Graph under test: quantize both operands, run ElemwiseMultiType,
        // output dtype pinned to d_dtype, read back as Float32.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{d_dtype};
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto b = opr::Copy::make(inputs[1], cpu);
            auto d = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype), opr::TypeCvt::make(b, y_dtype)},
                    {mode}, config);
            d = opr::TypeCvt::make(d, dtype::Float32());
            return {d};
        };
        // Float reference: matching plain Elemwise mode in Float32, then a
        // round trip through d_dtype.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            auto y = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[1]), y_dtype);
            SymbolVar d;
            auto xf = opr::TypeCvt::make(x, dtype::Float32());
            auto yf = opr::TypeCvt::make(y, dtype::Float32());
            // one case per mode in the loop list above (23 modes)
            switch (mode) {
                MAKE_BINARY(FUSE_ADD_RELU);
                MAKE_BINARY(ADD);
                MAKE_BINARY(MUL);
                MAKE_BINARY(MIN);
                MAKE_BINARY(MAX);
                MAKE_BINARY(SUB);
                MAKE_BINARY(TRUE_DIV);
                MAKE_BINARY(FUSE_ADD_SIGMOID);
                MAKE_BINARY(FUSE_ADD_TANH);
                MAKE_BINARY(FUSE_ADD_H_SWISH);
                MAKE_BINARY(ABS_GRAD);
                MAKE_BINARY(FLOOR_DIV);
                MAKE_BINARY(MOD);
                MAKE_BINARY(SIGMOID_GRAD);
                MAKE_BINARY(SWITCH_GT0);
                MAKE_BINARY(TANH_GRAD);
                MAKE_BINARY(LT);
                MAKE_BINARY(LEQ);
                MAKE_BINARY(EQ);
                MAKE_BINARY(POW);
                MAKE_BINARY(LOG_SUM_EXP);
                MAKE_BINARY(FAST_TANH_GRAD);
                MAKE_BINARY(ATAN2);
                default:
                    mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
                    break;
            }
            d = opr::TypeCvt::make(d, d_dtype);
            d = opr::TypeCvt::make(d, dtype::Float32());
            auto func = cg->compile({make_callback_copy(d, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        // Division-like modes need a positive second operand to avoid
        // dividing by (values that quantize to) zero.
        switch (mode) {
            case Mode::QTRUE_DIV:
            case Mode::QMOD:
            case Mode::QFLOOR_DIV:
                checker.set_input_generator(1, gen_postive);
                break;
            default:
                break;
        }
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}, {3, 4}})
                .run({TensorShape{4, 8}, {1, 1}})
                .run({TensorShape{9, 4, 8}, {9, 4, 8}}, options);
    }
}
#undef MAKE_BINARY
//! Expand one switch case mapping a quantized ternary mode Q<MODE> to the
//! Float32 reference Elemwise mode <MODE>, applied to `xf`, `yf` and `zf`
//! (the float-cast operands in scope at the expansion site).
#define MAKE_TERNARY(_MODE) \
    case Mode::Q##_MODE: \
        d = opr::Elemwise::make({xf, yf, zf}, {opr::Elemwise::Mode::_MODE}); \
        break
//! Quantized ternary modes with QuantizedS8 inputs and output, checked
//! against a Float32 Elemwise reference (via MAKE_TERNARY).
TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) {
    using Checker = AutoOprChecker<3, 1>;
    DType x_dtype = dtype::QuantizedS8(1.15f);
    DType y_dtype = dtype::QuantizedS8(2.0f);
    DType z_dtype = dtype::QuantizedS8(1.15f);
    DType d_dtype = dtype::QuantizedS8(1.15f);
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) {
        // Graph under test: quantize all three operands, run
        // ElemwiseMultiType, output dtype pinned to d_dtype.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{d_dtype};
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto b = opr::Copy::make(inputs[1], cpu);
            auto c = opr::Copy::make(inputs[2], cpu);
            auto d = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype), opr::TypeCvt::make(b, y_dtype),
                     opr::TypeCvt::make(c, z_dtype)},
                    {mode}, config);
            d = opr::TypeCvt::make(d, dtype::Float32());
            return {d};
        };
        // Float reference: matching plain Elemwise mode in Float32, then a
        // round trip through d_dtype.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            auto y = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[1]), y_dtype);
            auto z = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[2]), z_dtype);
            SymbolVar d;
            auto xf = opr::TypeCvt::make(x, dtype::Float32());
            auto yf = opr::TypeCvt::make(y, dtype::Float32());
            auto zf = opr::TypeCvt::make(z, dtype::Float32());
            switch (mode) {
                MAKE_TERNARY(FUSE_MUL_ADD3);
                MAKE_TERNARY(COND_LEQ_MOV);
                // NOTE(review): COND_LT_MOV is not in the mode list above,
                // so this case is unreachable here.
                MAKE_TERNARY(COND_LT_MOV);
                default:
                    mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
                    break;
            }
            d = opr::TypeCvt::make(d, d_dtype);
            d = opr::TypeCvt::make(d, dtype::Float32());
            auto func = cg->compile({make_callback_copy(d, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}, {3, 4}, {3, 4}})
                .run({TensorShape{4, 8}, {4, 8}, {4, 8}})
                .run({TensorShape{9, 4, 8}, {9, 4, 8}, {9, 4, 8}}, options);
    }
}
//! Quantized ternary modes with asymmetric Quantized8Asymm inputs and output
//! (zero point 128); same structure as the QuantizedS8 ternary test above.
TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) {
    using Checker = AutoOprChecker<3, 1>;
    DType x_dtype = dtype::Quantized8Asymm(1.15f, static_cast<uint8_t>(128));
    DType y_dtype = dtype::Quantized8Asymm(2.0f, static_cast<uint8_t>(128));
    DType z_dtype = dtype::Quantized8Asymm(1.15f, static_cast<uint8_t>(128));
    DType d_dtype = dtype::Quantized8Asymm(1.15f, static_cast<uint8_t>(128));
    using Mode = opr::ElemwiseMultiType::Param::Mode;
    for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) {
        // Graph under test: quantize all three operands, run
        // ElemwiseMultiType, output dtype pinned to d_dtype.
        auto make_graph =
                [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
            OperatorNodeConfig config{d_dtype};
            auto cpu = CompNode::load("cpux");
            auto a = opr::Copy::make(inputs[0], cpu);
            auto b = opr::Copy::make(inputs[1], cpu);
            auto c = opr::Copy::make(inputs[2], cpu);
            auto d = opr::ElemwiseMultiType::make(
                    {opr::TypeCvt::make(a, x_dtype), opr::TypeCvt::make(b, y_dtype),
                     opr::TypeCvt::make(c, z_dtype)},
                    {mode}, config);
            d = opr::TypeCvt::make(d, dtype::Float32());
            return {d};
        };
        // Float reference: matching plain Elemwise mode in Float32, then a
        // round trip through d_dtype.
        auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
            auto cg = ComputingGraph::make();
            cg->options().graph_opt_level = 0;  // keep reference graph literal
            auto x = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[0]), x_dtype);
            auto y = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[1]), y_dtype);
            auto z = opr::TypeCvt::make(
                    opr::Host2DeviceCopy::make(*cg, inp[2]), z_dtype);
            SymbolVar d;
            auto xf = opr::TypeCvt::make(x, dtype::Float32());
            auto yf = opr::TypeCvt::make(y, dtype::Float32());
            auto zf = opr::TypeCvt::make(z, dtype::Float32());
            switch (mode) {
                MAKE_TERNARY(FUSE_MUL_ADD3);
                MAKE_TERNARY(COND_LEQ_MOV);
                // NOTE(review): COND_LT_MOV is not in the mode list above,
                // so this case is unreachable here.
                MAKE_TERNARY(COND_LT_MOV);
                default:
                    mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
                    break;
            }
            d = opr::TypeCvt::make(d, d_dtype);
            d = opr::TypeCvt::make(d, dtype::Float32());
            auto func = cg->compile({make_callback_copy(d, dest[0])});
            func->execute().wait();
        };
        Checker checker{make_graph, fwd};
        Checker::RunOptions options;
        options.outputs_max_err = 0.2;  // loose bound: quantization error dominates
        checker.disable_grad_check()
                .run({TensorShape{3, 4}, {3, 4}, {3, 4}})
                .run({TensorShape{4, 8}, {4, 8}, {4, 8}})
                .run({TensorShape{9, 4, 8}, {9, 4, 8}, {9, 4, 8}}, options);
    }
}
  656. #undef MAKE_TERNARY
  657. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}