
blas.cpp 32 kB

/**
 * \file src/opr/test/blas.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/blas.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/megdnn_helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_gen.h"

#include <random>

using namespace mgb;

namespace {

// reference GEMM used to validate MatrixMul / BatchedMatrixMul outputs
template <typename dt_src, typename dt_dst>
void brute_force_gemm(size_t M, size_t N, size_t K, bool transa, bool transb,
                      const dt_src* x, const dt_src* y, dt_dst* z) {
    for (size_t m = 0; m < M; ++m)
        for (size_t n = 0; n < N; ++n) {
            dt_dst cur = dt_dst(0);
            for (size_t k = 0; k < K; ++k) {
                cur += x[transa ? (k * M + m) : (m * K + k)] *
                       y[transb ? (n * K + k) : (k * N + n)];
            }
            z[m * N + n] = cur;
        }
}

// reference dot product; a length-1 operand (stride forced to 0) is
// broadcast against the longer one
float brute_force_dot(const HostTensorND& a, const HostTensorND& b) {
    auto sz = std::max(a.shape(0), b.shape(0));
    size_t ap = 0, bp = 0;
    float ret = 0;
    auto pa = a.ptr<float>(), pb = b.ptr<float>();
    auto as = a.layout().stride[0], bs = b.layout().stride[0];
    if (a.shape(0) != sz)
        as = 0;
    if (b.shape(0) != sz)
        bs = 0;
    for (size_t i = 0; i < sz; ++i) {
        ret += pa[ap] * pb[bp];
        ap += as;
        bp += bs;
    }
    return ret;
}

// (m,k) * (k,n) = (m,n)
void run_sgemm_test(bool transa, bool transb) {
    using Checker = AutoOprChecker<2, 1>;
    auto make_graph =
            [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        auto param = opr::MatrixMul::Param{transa, transb};
        return {opr::MatrixMul::make(inputs[0], inputs[1], param)};
    };
    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        size_t M, N, K;
        M = inp[0]->shape().shape[0];
        K = inp[0]->shape().shape[1];
        if (transa)
            std::swap(M, K);
        N = inp[1]->shape().shape[transb ? 0 : 1];
        auto z = dest[0].comp_node(inp[0]->comp_node())
                         .resize({M, N})
                         .ptr<float>();
        // brute-force gemm
        brute_force_gemm(M, N, K, transa, transb, inp[0]->ptr<float>(),
                         inp[1]->ptr<float>(), z);
    };
    auto mkshp = [](bool trans, size_t m, size_t k) {
        TensorShape rst{m, k};
        if (trans)
            std::swap(rst.shape[0], rst.shape[1]);
        return rst;
    };
    using namespace std::placeholders;
    auto mkx = std::bind(mkshp, transa, _1, _2);
    auto mky = std::bind(mkshp, transb, _1, _2);
    Checker::RunOptions opt;
    opt.numdiff_eps = 1;
    Checker(make_graph, fwd)
            .run({mkx(4, 6), mky(6, 2)}, opt)
            .run({mkx(2, 3), mky(3, 100)}, opt)
            .run({mkx(20, 3), mky(3, 20)}, opt);
}

// expands to a lambda that computes the reference result for the batched
// matmul checkers by running brute_force_gemm on every batch slice
#define FWD_BATCH_GEMM(dt_src, dt_dst)                                       \
    [transa, transb](Checker::NumOutArray& dest, Checker::NumInpArray inp) { \
        bool ta(transa), tb(transb);                                         \
        HostTensorND a, b;                                                   \
        size_t B, M, N, K;                                                   \
        a.copy_from(*(inp[0]));                                              \
        b.copy_from(*(inp[1]));                                              \
        B = a.shape().shape[0];                                              \
        M = a.shape().shape[1];                                              \
        K = a.shape().shape[2];                                              \
        N = b.shape().shape[tb ? 1 : 2];                                     \
        if (ta)                                                              \
            std::swap(M, K);                                                 \
        auto x = a.ptr<dt_src>(), y = b.ptr<dt_src>();                       \
        auto z = dest[0].resize({B, M, N}).ptr<dt_dst>();                    \
        for (size_t b = 0; b < B; ++b) {                                     \
            brute_force_gemm(M, N, K, ta, tb, x + b * M * K, y + b * K * N,  \
                             z + b * M * N);                                 \
        }                                                                    \
    }

void run_batched_sgemm_test(bool transa, bool transb) {
    using Checker = AutoOprChecker<2, 1>;
    auto make_graph =
            [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::BatchedMatrixMul::make(inputs[0], inputs[1],
                                            {transa, transb})};
    };
    auto fwd = FWD_BATCH_GEMM(float, float);
    auto mkshp = [](bool trans, size_t b, size_t m, size_t k) {
        TensorShape rst{b, m, k};
        if (trans)
            std::swap(rst.shape[1], rst.shape[2]);
        return rst;
    };
    using namespace std::placeholders;
    auto mkx = std::bind(mkshp, transa, _1, _2, _3);
    auto mky = std::bind(mkshp, transb, _1, _2, _3);
    Checker::RunOptions opt;
    opt.numdiff_eps = 1;
    Checker(make_graph, fwd)
            .run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
            .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
}

auto gen_fp16 = [](HostTensorND& dest) {
    RNGxorshf rng{next_rand_seed()};
    auto rand_real = [&rng]() {
        std::uniform_real_distribution<float> dist(-1, 1);
        return dist(rng);
    };
    auto ptr = dest.ptr<dt_float16>();
    size_t elems = dest.shape().total_nr_elems();
    for (size_t i = 0; i < elems; i++) {
        ptr[i] = dt_float16(rand_real());
    }
};

auto gen_int8 = [](HostTensorND& dest) {
    HostTensorGenerator<dtype::Int8, RandomDistribution::UNIFORM>
            int8_generator{-128, 127};
    dest = *int8_generator(dest.shape(), dest.comp_node());
};

void run_batched_hgemm_test(bool transa, bool transb) {
    using Checker = AutoOprChecker<2, 1>;
    auto make_graph =
            [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::BatchedMatrixMul::make(inputs[0], inputs[1],
                                            {transa, transb})};
    };
    auto fwd = FWD_BATCH_GEMM(dt_float16, dt_float16);
    auto mkshp = [](bool trans, size_t b, size_t m, size_t k) {
        TensorShape rst{b, m, k};
        if (trans)
            std::swap(rst.shape[1], rst.shape[2]);
        return rst;
    };
    using namespace std::placeholders;
    auto mkx = std::bind(mkshp, transa, _1, _2, _3);
    auto mky = std::bind(mkshp, transb, _1, _2, _3);
    Checker checker(make_graph, fwd);
    Checker::RunOptions opt;
    opt.outputs_max_err = 1e-2;
    checker.set_input_dtype(0, dtype::Float16())
            .set_input_dtype(1, dtype::Float16())
            .set_input_generator(0, gen_fp16)
            .set_input_generator(1, gen_fp16)
            .set_input_allow_grad(0, false)
            .set_input_allow_grad(1, false)
            .set_output_allow_grad(0, false);
    checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
            .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
}

void run_batched_igemm_test(bool transa, bool transb) {
    using Checker = AutoOprChecker<2, 1>;
    auto make_graph =
            [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::BatchedMatrixMul::make(inputs[0], inputs[1],
                                            {transa, transb})};
    };
    auto fwd = FWD_BATCH_GEMM(int8_t, int32_t);
    auto mkshp = [](bool trans, size_t b, size_t m, size_t k) {
        TensorShape rst{b, m, k};
        if (trans)
            std::swap(rst.shape[1], rst.shape[2]);
        return rst;
    };
    using namespace std::placeholders;
    auto mkx = std::bind(mkshp, transa, _1, _2, _3);
    auto mky = std::bind(mkshp, transb, _1, _2, _3);
    Checker::RunOptions opt;
    opt.numdiff_eps = 1;
    Checker checker(make_graph, fwd);
    checker.set_input_dtype(0, dtype::Int8())
            .set_input_dtype(1, dtype::Int8())
            .set_input_generator(0, gen_int8)
            .set_input_generator(1, gen_int8)
            .set_input_allow_grad(0, false)
            .set_input_allow_grad(1, false)
            .set_output_allow_grad(0, false);
    checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
            .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
}

template <typename ctype>
float getter(ctype val) {
    return val;
}

template <>
float getter<dt_qint32>(dt_qint32 val) {
    return (float)val.as_int32();
}

template <typename dt_src, typename dt_dst>
void run_trans_inp_test_case(bool trans_a, bool trans_b) {
    HostTensorGenerator<typename DTypeTrait<dt_src>::dtype> gen;
    std::shared_ptr<HostTensorND> host_x = gen({1, 1}), host_y = gen({1, 1});
    auto graph = ComputingGraph::make();
    auto do_trans = [](SymbolVar x) {
        return opr::Dimshuffle::make(x, {1, 0});
    };
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::Host2DeviceCopy::make(*graph, host_y);
    if (trans_a) {
        x = do_trans(x);
    }
    if (trans_b) {
        y = do_trans(y);
    }
    OperatorNodeConfig config;
    if (DTypeTrait<dt_dst>::enumv == DTypeEnum::Int16) {
        config.output_dtype(dtype::Int16());
    }
    auto z = opr::MatrixMul::make(x, y, {}, {}, config);
    HostTensorND host_z;
    auto func = graph->compile({make_callback_copy(z, host_z)});
    auto run = [&](size_t M, size_t K, size_t N) {
        *host_x = *(trans_a ? gen({K, M}) : gen({M, K}));
        *host_y = *(trans_b ? gen({N, K}) : gen({K, N}));
        func->execute();
        ASSERT_EQ(TensorShape({M, N}), host_z.shape());
        ASSERT_EQ(!trans_a, x.node()->dev_tensor().layout().is_contiguous());
        ASSERT_EQ(!trans_b, y.node()->dev_tensor().layout().is_contiguous());
        auto px = host_x->ptr<dt_src>(), py = host_y->ptr<dt_src>();
        auto pz = host_z.ptr<dt_dst>();
        auto make_strd = [](bool trans, int h, int w, int* dst) {
            if (trans) {
                dst[0] = 1;
                dst[1] = h;
            } else {
                dst[0] = w;
                dst[1] = 1;
            }
        };
        int strd_x[2], strd_y[2];
        make_strd(trans_a, M, K, strd_x);
        make_strd(trans_b, K, N, strd_y);
        for (size_t i = 0; i < M; ++i) {
            for (size_t j = 0; j < N; ++j) {
                dt_dst sum = 0;
                for (size_t k = 0; k < K; ++k) {
                    dt_dst xv = px[i * strd_x[0] + k * strd_x[1]],
                           yv = py[k * strd_y[0] + j * strd_y[1]];
                    sum += xv * yv;
                }
                MGB_ASSERT_FLOAT_EQ(getter(sum), getter(pz[i * N + j]))
                        << trans_a << ' ' << trans_b;
            }
        }
    };
    run(4, 8, 12);
    run(8, 12, 16);
}

template <typename dt_src, typename dt_dst>
void run_trans_inp_test() {
    for (bool ta : {false, true}) {
        for (bool tb : {false, true}) {
            run_trans_inp_test_case<dt_src, dt_dst>(ta, tb);
        }
    }
}

template <typename dt_src, typename dt_dst>
void inline mul_add(dt_src& a, dt_src& b, dt_dst& c) {
    c += dt_dst(a) * dt_dst(b);
}

template <>
void inline mul_add(dt_qint8& a, dt_qint8& b, dt_qint32& c) {
    c += dt_qint32(a.as_int8()) * dt_qint32(b.as_int8());
}

template <typename dt_gen>
std::shared_ptr<HostTensorND> bgemm_gen(const TensorShape& shp) {
    HostTensorGenerator<typename DTypeTrait<dt_gen>::dtype> gen;
    return gen(shp);
}

template <>
std::shared_ptr<HostTensorND> bgemm_gen<dt_float16>(const TensorShape& shp) {
    CompNode cn = CompNode::load("xpu0");
    std::shared_ptr<HostTensorND> ret =
            std::make_shared<HostTensorND>(cn, dtype::Float16{});
    (*ret).resize(shp);
    gen_fp16(*ret);
    return ret;
}

template <typename dt_src, typename dt_dst>
void run_bgemm_trans_inp_test_case(bool trans_a, bool trans_b) {
    std::shared_ptr<HostTensorND> host_x = bgemm_gen<dt_src>({1, 1, 1}),
                                  host_y = bgemm_gen<dt_src>({1, 1, 1});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::Host2DeviceCopy::make(*graph, host_y);
    if (trans_a) {
        x = opr::Dimshuffle::make(x, {0, 2, 1});
    }
    if (trans_b) {
        y = opr::Dimshuffle::make(y, {0, 2, 1});
    }
    auto z = opr::BatchedMatrixMul::make(x, y, {}, {}, OperatorNodeConfig{});
    HostTensorND host_z;
    auto func = graph->compile({make_callback_copy(z, host_z)});
    auto run = [&](size_t B, size_t M, size_t K, size_t N) {
        *host_x = *(trans_a ? bgemm_gen<dt_src>({B, K, M})
                            : bgemm_gen<dt_src>({B, M, K}));
        *host_y = *(trans_b ? bgemm_gen<dt_src>({B, N, K})
                            : bgemm_gen<dt_src>({B, K, N}));
        func->execute();
        ASSERT_EQ(TensorShape({B, M, N}), host_z.shape());
        ASSERT_EQ(!trans_a, x.node()->dev_tensor().layout().is_contiguous());
        ASSERT_EQ(!trans_b, y.node()->dev_tensor().layout().is_contiguous());
        int strd_x[3], strd_y[3];
        auto px = host_x->ptr<dt_src>(), py = host_y->ptr<dt_src>();
        auto pz = host_z.ptr<dt_dst>();
        auto make_strd = [](bool trans, int h, int w, int* dst) {
            dst[0] = h * w;
            dst[1] = trans ? 1 : w;
            dst[2] = trans ? h : 1;
        };
        make_strd(trans_a, M, K, strd_x);
        make_strd(trans_b, K, N, strd_y);
        for (size_t b = 0; b < B; ++b)
            for (size_t i = 0; i < M; ++i)
                for (size_t j = 0; j < N; ++j) {
                    dt_dst sum = dt_dst(0);
                    for (size_t k = 0; k < K; ++k) {
                        dt_src xv = px[b * strd_x[0] + i * strd_x[1] +
                                       k * strd_x[2]],
                               yv = py[b * strd_y[0] + k * strd_y[1] +
                                       j * strd_y[2]];
                        mul_add(xv, yv, sum);
                    }
                    MGB_ASSERT_FLOAT_NEAR(getter(sum),
                                          getter(pz[(b * M + i) * N + j]),
                                          5e-3)
                            << trans_a << ' ' << trans_b;
                }
    };
    run(2, 4, 8, 12);
    run(2, 8, 12, 16);
}

}  // anonymous namespace

TEST(TestOprBlas, MatrixMul_NN) {
    run_sgemm_test(false, false);
}

TEST(TestOprBlas, MatrixMul_NT) {
    run_sgemm_test(false, true);
}

TEST(TestOprBlas, MatrixMul_TN) {
    run_sgemm_test(true, false);
}

TEST(TestOprBlas, MatrixMul_TT) {
    run_sgemm_test(true, true);
}

TEST(TestOprDNN, MatrixMulExePolicy) {
    using Param = opr::MatrixMul::Param;
    Param param;
    using Policy = opr::MatrixMul::ExecutionPolicy;
    using S = Policy::Strategy;
    auto cn = CompNode::load("cpux");
#if MGB_ENABLE_FASTRUN
    for (auto strategy :
         SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
                        S::PROFILE | S::HEURISTIC}) {
#else
    for (auto strategy : {S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif
        auto graph = ComputingGraph::make();
        HostTensorGenerator<> gen;
        auto mkvar = [&](const char* name, const TensorShape& shp) {
            return opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
                    .rename(name);
        };
        auto A = mkvar("A", {32, 64});
        auto B = mkvar("B", {64, 32});
        Policy policy;
        policy.strategy = strategy;
        auto C = opr::MatrixMul::make(A, B, param, policy);
        HostTensorND host_c;
        auto func = graph->compile({make_callback_copy(C, host_c)});
        func->execute();
    }
}

TEST(TestOprBlas, BatchedMatrixMulFp32_NN) {
    run_batched_sgemm_test(false, false);
}

TEST(TestOprBlas, BatchedMatrixMulFp32_NT) {
    run_batched_sgemm_test(false, true);
}

TEST(TestOprBlas, BatchedMatrixMulFp32_TN) {
    run_batched_sgemm_test(true, false);
}

TEST(TestOprBlas, BatchedMatrixMulFp32_TT) {
    run_batched_sgemm_test(true, true);
}

TEST(TestOprBlas, BatchedMatrixMulFp16_NN) {
    run_batched_hgemm_test(false, false);
}

TEST(TestOprBlas, BatchedMatrixMulFp16_NT) {
    run_batched_hgemm_test(false, true);
}

TEST(TestOprBlas, BatchedMatrixMulFp16_TN) {
    run_batched_hgemm_test(true, false);
}

TEST(TestOprBlas, BatchedMatrixMulFp16_TT) {
    run_batched_hgemm_test(true, true);
}

TEST(TestOprBlas, BatchedMatrixMulInt8_NN) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_batched_igemm_test(false, false);
}

TEST(TestOprBlas, BatchedMatrixMulInt8_NT) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_batched_igemm_test(false, true);
}

TEST(TestOprBlas, BatchedMatrixMulInt8_TN) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_batched_igemm_test(true, false);
}

TEST(TestOprBlas, BatchedMatrixMulInt8_TT) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_batched_igemm_test(true, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp32_NN) {
    run_bgemm_trans_inp_test_case<float, float>(false, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp32_NT) {
    run_bgemm_trans_inp_test_case<float, float>(false, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp32_TN) {
    run_bgemm_trans_inp_test_case<float, float>(true, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp32_TT) {
    run_bgemm_trans_inp_test_case<float, float>(true, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulInt8_NN) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<int8_t, int32_t>(false, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulInt8_NT) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<int8_t, int32_t>(false, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulInt8_TN) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<int8_t, int32_t>(true, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulInt8_TT) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<int8_t, int32_t>(true, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp16_NN) {
    run_bgemm_trans_inp_test_case<dt_float16, dt_float16>(false, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp16_NT) {
    run_bgemm_trans_inp_test_case<dt_float16, dt_float16>(false, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp16_TN) {
    run_bgemm_trans_inp_test_case<dt_float16, dt_float16>(true, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulFp16_TT) {
    run_bgemm_trans_inp_test_case<dt_float16, dt_float16>(true, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulQS8_NN) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<dt_qint8, dt_qint32>(false, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulQS8_NT) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<dt_qint8, dt_qint32>(false, true);
}

TEST(TestOprBlas, TransBatchedMatrixMulQS8_TN) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<dt_qint8, dt_qint32>(true, false);
}

TEST(TestOprBlas, TransBatchedMatrixMulQS8_TT) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_bgemm_trans_inp_test_case<dt_qint8, dt_qint32>(true, true);
}

TEST(TestOprBlas, DotBasic) {
    HostTensorGenerator<> gen;
    auto host_x = gen({123}), host_y = gen({123});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::Host2DeviceCopy::make(*graph, host_y),
         z = opr::Dot::make(x, y);
    HostTensorND host_z;
    auto func = graph->compile({make_callback_copy(z, host_z)});
    func->execute();
    MGB_ASSERT_FLOAT_EQ(brute_force_dot(*host_x, *host_y),
                        *host_z.ptr<float>());
}

TEST(TestOprBlas, Dot) {
    using Checker = AutoOprChecker<2, 1>;
    auto make_graph =
            [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::Dot::make(inputs[0], inputs[1])};
    };
    auto fwd = [](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto &&i0 = *inp[0], &&i1 = *inp[1];
        auto&& out = dest[0].resize({1});
        *out.ptr<float>() = brute_force_dot(i0, i1);
    };
    Checker(make_graph, fwd)
            .run({TensorShape{15}, TensorShape{1}})
            .run({TensorShape{1}, TensorShape{16}})
            .run({TensorShape{23}, TensorShape{23}})
            .run({TensorShape{1000}, TensorShape{1000}});
}

TEST(TestOprBlas, TransMatMul) {
    run_trans_inp_test<float, float>();
}

TEST(TestOprBlas, TransMatMul8x8x16) {
    if (CompNode::load("xpux").device_type() != CompNode::DeviceType::CUDA) {
        run_trans_inp_test<dt_int8, dt_int16>();
    } else {
        printf("testcase skipped on unsupported arch\n");
    }
}

TEST(TestOprBlas, TransMatMul8x8x32) {
    if (CompNode::load("xpux").device_type() == CompNode::DeviceType::CUDA &&
        !check_compute_capability(6, 1)) {
        return;
    }
    run_trans_inp_test<dt_int8, dt_int32>();
}

TEST(TestOprBlas, NonContigMatmul) {
    using Checker = AutoOprChecker<2, 1>;
    auto make_graph =
            [](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        using Ad = opr::Subtensor::AxisIndexer;
        auto x = inputs[0],
             xsub = opr::Subtensor::make(
                     x, {Ad::make_interval(0, None, None, x.make_scalar(2))}),
             y = inputs[1],
             ysub = opr::Subtensor::make(
                     y, {Ad::make_interval(1, None, None, x.make_scalar(3))});
        return {opr::MatrixMul::make(xsub, ysub)};
    };
    auto fwd = [](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto &&shp0 = inp[0]->shape(), &&shp1 = inp[1]->shape();
        size_t m = (shp0.shape[0] + 1) / 2, k = shp0.shape[1],
               n = (shp1.shape[1] + 2) / 3;
        auto dptr = dest[0].resize({m, n}).ptr<float>();
        memset(dptr, 0, sizeof(float) * m * n);
        for (size_t i = 0; i < m; ++i) {
            auto ptr_a = inp[0]->ptr<float>({i * 2}),
                 ptr_c = dest[0].ptr<float>({i});
            for (size_t kk = 0; kk < k; ++kk) {
                auto va = ptr_a[kk];
                auto ptr_b = inp[1]->ptr<float>({kk});
                for (size_t j = 0; j < n; ++j) {
                    ptr_c[j] += va * ptr_b[j * 3];
                }
            }
        }
    };
    Checker(make_graph, fwd)
            .run({TensorShape{2, 1}, TensorShape{1, 3}})
            .run({TensorShape{5, 2}, TensorShape{2, 6}})
            .run({TensorShape{6, 3}, TensorShape{3, 8}});
}

TEST(TestOprBlas, MatrixInverse) {
    using Checker = AutoOprChecker<1, 1>;
    auto make_graph =
            [=](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::MatrixInverse::make(inputs[0])};
    };
    auto fwd = [=](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto opr =
                megdnn_naive_handle()->create_operator<megdnn::MatrixInverse>();
        auto wk_size =
                opr->get_workspace_in_bytes(inp[0]->layout(), inp[0]->layout());
        std::unique_ptr<dt_byte[]> wk{new dt_byte[wk_size]};
        opr->exec(inp[0]->as_megdnn(),
                  dest[0].resize(inp[0]->shape()).as_megdnn(),
                  {wk.get(), wk_size});
    };
    // ensure low condition number for generated matrices
    auto input_coord = [](const Checker::NumInpArray& inp) {
        auto shp = inp[0]->shape();
        size_t n = shp[shp.ndim - 1];
        size_t batch = 1;
        for (size_t i = 0; i < shp.ndim - 2; ++i) {
            batch *= shp[i];
        }
        std::vector<int> perm(n);
        for (size_t i = 0; i < n; ++i) {
            perm[i] = i;
        }
        auto ptr = inp[0]->ptr<float>();
        for (size_t i = 0; i < batch; ++i, ptr += n * n) {
#if __cplusplus >= 201703L
            std::default_random_engine rng_engine;
            std::shuffle(perm.begin(), perm.end(), rng_engine);
#else
            std::random_shuffle(perm.begin(), perm.end());
#endif
            for (size_t j = 0; j < n; ++j) {
                ptr[j * n + perm[j]] += 5;
            }
        }
    };
    Checker{make_graph, fwd}
            .set_input_coordinator(input_coord)
            .run({TensorShape{5, 5}})
            .run({TensorShape{2, 5, 5}})
            .run({TensorShape{2, 6, 3, 3}});
}

namespace {

void gen_svd_input(HostTensorND& dest) {
    auto ptr = dest.ptr<float>();
    auto dim = dest.layout().ndim;
    size_t n = dest.layout().shape[dim - 2], m = dest.layout().shape[dim - 1];
    size_t j = 0, k = 0;
    float batch_off = 0;
    float max_val = std::min(m, n) * std::min(m, n) + 0.99;
    for (size_t i = 0, it = dest.layout().total_nr_elems(); i < it; ++i) {
        if (i % (n * m) == 0) {
            batch_off += 0.32;
            j = k = 0;
        }
        if (!((i % (n * m)) % (m + 1)))
            ptr[i] = (j++) + ((++k / 10.0));
        else
            ptr[i] = (j++);
        ptr[i] += batch_off;
        ptr[i] = std::fmod(ptr[i], max_val);
    }
}

template <int have_u, int have_s, int have_v>
void run_svd_empty_grad_test() {
    using Checker = AutoOprChecker<1, have_u + have_s + have_v>;
    auto make_graph = [=](const typename Checker::SymInpArray& inputs) {
        auto out = opr::SVD::make(inputs[0], opr::SVD::Param{false, true});
        typename Checker::SymOutArray ret;
        int idx = 0;
        if (have_u) {
            ret[idx++] = out[0];
        }
        if (have_s) {
            ret[idx++] = out[1];
        }
        if (have_v) {
            ret[idx++] = out[2];
        }
        return ret;
    };
    auto fwd = [=](typename Checker::NumOutArray& dest,
                   typename Checker::NumInpArray inp) {
        auto opr = megdnn_naive_handle()->create_operator<megdnn::SVDForward>();
        opr->param().compute_uv = true;
        TensorLayout ul, sl, vtl;
        opr->deduce_layout(inp[0]->layout(), ul, sl, vtl);
        HostTensorND tmp_u{dest[0].comp_node(), ul},
                tmp_s{dest[0].comp_node(), sl}, tmp_v{dest[0].comp_node(), vtl};
        auto wk_size =
                opr->get_workspace_in_bytes(inp[0]->layout(), ul, sl, vtl);
        auto wk = std::make_unique<dt_byte[]>(wk_size);
        auto out0 = tmp_u.as_megdnn(), out1 = tmp_s.as_megdnn(),
             out2 = tmp_v.as_megdnn();
        int idx = 0;
        if (have_u) {
            out0 = dest[idx++].resize(ul).as_megdnn();
        }
        if (have_s) {
            out1 = dest[idx++].resize(sl).as_megdnn();
        }
        if (have_v) {
            out2 = dest[idx++].resize(vtl).as_megdnn();
        }
        opr->exec(inp[0]->as_megdnn(), out0, out1, out2, {wk.get(), wk_size});
    };
    Checker checker{make_graph, fwd};
    checker.set_input_generator(0, gen_svd_input);
    if (have_u) {
        checker.set_output_allow_check(0, false);
    }
    if (have_v) {
        checker.set_output_allow_check(have_u + have_s, false);
    }
    checker.run({TensorShape{3, 3}})
            .run({TensorShape{2, 3, 3}})
            .run({TensorShape{2, 4, 2}})
            .run({TensorShape{3, 1, 2, 4}})
            .run({TensorShape{2, 3, 2, 3}});
}

}  // anonymous namespace

TEST(TestOprBlas, SingularValueDecomposition) {
    using Checker = AutoOprChecker<1, 3>;
    auto make_graph =
            [=](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        auto out = opr::SVD::make(inputs[0], opr::SVD::Param{false, true});
        return {out[0], out[1], out[2]};
    };
    auto fwd = [=](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto opr = megdnn_naive_handle()->create_operator<megdnn::SVDForward>();
        opr->param().compute_uv = true;
        TensorLayout ul, sl, vtl;
        opr->deduce_layout(inp[0]->layout(), ul, sl, vtl);
        auto wk_size =
                opr->get_workspace_in_bytes(inp[0]->layout(), ul, sl, vtl);
        auto wk = std::make_unique<dt_byte[]>(wk_size);
        opr->exec(inp[0]->as_megdnn(), dest[0].resize(ul).as_megdnn(),
                  dest[1].resize(sl).as_megdnn(),
                  dest[2].resize(vtl).as_megdnn(), {wk.get(), wk_size});
    };
    Checker{make_graph, fwd}
            .set_input_generator(0, gen_svd_input)
            .set_output_allow_check(0, false)
            .set_output_allow_check(2, false)
            .run({TensorShape{3, 3}})
            .run({TensorShape{2, 3, 3}})
            .run({TensorShape{2, 4, 2}})
            .run({TensorShape{3, 1, 2, 4}})
            .run({TensorShape{2, 3, 2, 3}});
}

TEST(TestOprBlas, SingularValueDecompositionZeroGrad) {
    run_svd_empty_grad_test<0, 0, 1>();
    run_svd_empty_grad_test<0, 1, 0>();
    run_svd_empty_grad_test<0, 1, 1>();
    run_svd_empty_grad_test<1, 0, 0>();
    run_svd_empty_grad_test<1, 0, 1>();
    run_svd_empty_grad_test<1, 1, 0>();
    run_svd_empty_grad_test<1, 1, 1>();
}

#if MGB_ENABLE_FASTRUN
TEST(TestOprBlas, MatrixMulExePolicy) {
    using Param = opr::MatrixMul::Param;
    Param param;
    using Policy = opr::MatrixMul::ExecutionPolicy;
    using S = Policy::Strategy;
    Policy policy;
    policy.strategy = S::PROFILE;
    auto cn = CompNode::load("cpux");
    int nr_get = 0;
    auto on_get = [&nr_get](const std::string&, const void*, size_t,
                            const void*, size_t) { ++nr_get; };
    PersistentCacheHook cache_hook{on_get};
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name);
    };
    auto a = mkvar("a", {20, 50});
    auto b = mkvar("b", {50, 40});
    auto matmul = opr::MatrixMul::make(a, b, param, policy, {});
    HostTensorND host_y;
    graph->options().no_profiling_on_shape_change = true;
    auto func = graph->compile({make_callback_copy(matmul, host_y)});
    func->execute();
    ASSERT_GT(nr_get, 0);
    int nr = nr_get;
    graph->options().no_profiling_on_shape_change = false;
    func = graph->compile({make_callback_copy(matmul, host_y)});
    func->execute();
    ASSERT_GT(nr_get, nr);
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit MegStudio.
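
For reference, the GPU-only tests in blas.cpp above decide at runtime whether to skip themselves by inspecting which device the "xpux" comp node resolved to. A standalone host program can perform the same probe with the CompNode API used throughout the test file; the sketch below is illustrative only (the header path and the standalone main() are assumptions, not part of blas.cpp):

#include "megbrain/comp_node.h"  // assumed header providing mgb::CompNode
#include <cstdio>

int main() {
    // "xpux" asks MegBrain for the best available device, falling back to CPU
    auto cn = mgb::CompNode::load("xpux");
    if (cn.device_type() == mgb::CompNode::DeviceType::CUDA) {
        std::printf("CUDA device available: GPU test cases will run\n");
    } else {
        std::printf("no CUDA device: GPU-only test cases are skipped\n");
    }
    return 0;
}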