
matrix_mul.cpp

#include "test/naive/fixture.h"

#include "megdnn/oprs/linalg.h"
#include "test/common/checker.h"
#include "test/common/extra_impl_helper.h"
#include "test/common/matrix_mul.h"
#include "test/common/random_state.h"

using namespace megdnn;
using namespace test;

namespace {
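// Reference implementation shared by the MK-format tests below: the blocked
// (MK4 / MK4_DOT / MK8) operands are relayouted into the DEFAULT row-major
// layout, multiplied with a DEFAULT-format MatrixMul, and the result is
// relayouted back into the blocked output layout for the Checker to compare.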
void run_matmul_mk_format(
        Handle* handle, param::MatrixMul::Format format, DType Atype, DType Btype,
        DType Ctype) {
    using namespace matrix_mul;
    std::vector<TestArg> args = get_matmul_args();
    Checker<MatrixMul> checker(handle);
    auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param,
                         Handle* handle, size_t pack_size) {
        megdnn_assert(
                (param.format == param::MatrixMul::Format::MK4 ||
                 param.format == param::MatrixMul::Format::MK4_DOT ||
                 param.format == param::MatrixMul::Format::MK8) &&
                tensors.size() == 3);
        param::MatrixMul new_param = param;
        new_param.format = param::MatrixMul::Format::DEFAULT;
        size_t M = tensors[2].layout[0] * pack_size;
        size_t N = tensors[2].layout[1];
        size_t K = tensors[0].layout[1 - param.transposeA] * pack_size;
        TensorLayoutArray default_layouts, mk4_layouts;
        if (param.transposeA) {
            default_layouts.emplace_back(tensors[0].layout.reshape({K, M}));
            if (param.format == param::MatrixMul::Format::MK4_DOT) {
                mk4_layouts.emplace_back(default_layouts.back()
                                                 .reshape(
                                                         {K / pack_size, M / pack_size,
                                                          pack_size, pack_size})
                                                 .dimshuffle({0, 3, 1, 2}));
            } else {
                mk4_layouts.emplace_back(default_layouts.back()
                                                 .reshape(
                                                         {K / pack_size, M / pack_size,
                                                          pack_size, pack_size})
                                                 .dimshuffle({0, 2, 1, 3}));
            }
        } else {
            default_layouts.emplace_back(tensors[0].layout.reshape({M, K}));
            if (param.format == param::MatrixMul::Format::MK4_DOT) {
                mk4_layouts.emplace_back(default_layouts.back()
                                                 .reshape(
                                                         {M / pack_size, K / pack_size,
                                                          pack_size, pack_size})
                                                 .dimshuffle({0, 2, 1, 3}));
            } else {
                mk4_layouts.emplace_back(default_layouts.back()
                                                 .reshape(
                                                         {M / pack_size, K / pack_size,
                                                          pack_size, pack_size})
                                                 .dimshuffle({0, 3, 1, 2}));
            }
        }
        if (param.transposeB) {
            default_layouts.emplace_back(tensors[1].layout.reshape({N, K}));
            mk4_layouts.emplace_back(default_layouts.back()
                                             .reshape({N, K / pack_size, pack_size})
                                             .dimshuffle({0, 1, 2}));
        } else {
            default_layouts.emplace_back(tensors[1].layout.reshape({K, N}));
            mk4_layouts.emplace_back(default_layouts.back()
                                             .reshape({K / pack_size, N, pack_size})
                                             .dimshuffle({0, 2, 1}));
        }
        default_layouts.emplace_back(tensors[2].layout.reshape({M, N}));
        mk4_layouts.emplace_back(default_layouts.back()
                                         .reshape({M / pack_size, N, pack_size})
                                         .dimshuffle({0, 2, 1}));
        auto matmul_opr = handle->create_operator<MatrixMul>();
        matmul_opr->param() = new_param;
        size_t matmul_workspace = matmul_opr->get_workspace_in_bytes(
                default_layouts[0], default_layouts[1], default_layouts[2]);
        auto relayout_opr = handle->create_operator<Relayout>();
        WorkspaceBundle wb(
                nullptr, {default_layouts[0].span().dist_byte(),
                          default_layouts[1].span().dist_byte(),
                          default_layouts[2].span().dist_byte(), matmul_workspace});
        wb.set(malloc(wb.total_size_in_bytes()));
        TensorNDArray default_tensors, mk4_tensors;
        for (size_t i = 0; i < 3; i++) {
            default_tensors.emplace_back(wb.get(i), default_layouts[i]);
            mk4_tensors.emplace_back(tensors[i].raw_ptr(), mk4_layouts[i]);
        }
        relayout_opr->exec(mk4_tensors[0], default_tensors[0]);
        relayout_opr->exec(mk4_tensors[1], default_tensors[1]);
        matmul_opr->exec(
                default_tensors[0], default_tensors[1], default_tensors[2],
                wb.get_workspace(3));
        relayout_opr->exec(default_tensors[2], mk4_tensors[2]);
        free(wb.ptr());
    };
    size_t pack_size = MatrixMulForward::pack_size(format);
    for (auto&& arg : args) {
        if (arg.m % pack_size != 0 || arg.k % pack_size != 0)
            continue;
        param::MatrixMul param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        param.format = format;
        size_t m = arg.m, n = arg.n, k = arg.k;
        TensorShape A, B;
        if (param.transposeA) {
            A = TensorShape{k / pack_size, m / pack_size, pack_size, pack_size};
        } else {
            A = TensorShape{m / pack_size, k / pack_size, pack_size, pack_size};
        }
        if (param.transposeB) {
            B = TensorShape{n, k / pack_size, pack_size};
        } else {
            B = TensorShape{k / pack_size, n, pack_size};
        }
        checker.set_extra_opr_impl(
                std::bind(extra_impl, std::placeholders::_1, param, handle, pack_size));
        checker.set_dtype(0, Atype)
                .set_dtype(1, Btype)
                .set_dtype(2, Ctype)
                .set_epsilon(1e-3)
                .set_param(param)
                .execs({A, B, {}});
    }
}
}  // namespace
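// 4-bit asymmetric quantized matmul: GenTensorValueQuint4 packs two unsigned
// 4-bit values into each byte (low nibble first); the QuantizedS32 output
// scale is the product of the two input scales (0.3f * 0.3f).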
TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    auto GenTensorValueQuint4 = [](const TensorShape& shape,
                                   dtype::Quantized4Asymm dtype,
                                   const std::vector<int>& values) {
        TensorND tensor;
        tensor.layout = {shape, dtype};
        tensor.reset_ptr(
                static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte())));
        uint8_t* ptr = static_cast<uint8_t*>(tensor.raw_ptr());
        megdnn_assert(values.size() == tensor.layout.span().dist_elem());
        for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) {
            int val0 = values[i], val1 = values[i + 1];
            ptr[i / 2] = val0 | (val1 << 4);
        }
        return tensor;
    };
    using Param = MatrixMul::Param;
    Param param;
    checker.set_param(param);
    checker.set_dtype(2, dtype::QuantizedS32(0.3f * 0.3f));
    checker.exect(
            Testcase{
                    GenTensorValueQuint4(
                            {8, 8}, dtype::Quantized4Asymm(0.3f, (uint8_t)8),
                            {13, 2, 4, 13, 9, 3, 14, 14, 14, 5, 3, 3, 15,
                             11, 8, 8, 5, 7, 14, 15, 8, 2, 11, 1, 15, 9,
                             13, 14, 2, 3, 11, 11, 15, 10, 11, 0, 13, 12, 3,
                             11, 9, 9, 10, 5, 2, 5, 8, 4, 6, 9, 0, 0,
                             3, 9, 9, 8, 8, 15, 7, 5, 0, 3, 9, 10}),
                    GenTensorValueQuint4(
                            {8, 8}, dtype::Quantized4Asymm(0.3f, (uint8_t)8),
                            {5, 14, 13, 11, 4, 7, 12, 12, 11, 7, 13, 10, 5,
                             6, 4, 2, 3, 12, 2, 2, 13, 3, 14, 0, 15, 15,
                             0, 2, 2, 13, 3, 14, 10, 8, 9, 11, 0, 14, 15,
                             4, 14, 7, 1, 6, 13, 2, 12, 5, 2, 15, 7, 11,
                             13, 9, 8, 10, 0, 11, 6, 10, 12, 2, 2, 12}),
                    {}},
            Testcase{
                    {},
                    {},
                    TensorValue(
                            {8, 8}, dtype::QuantizedS32(0.3f * 0.3f),
                            {-90, 120, -3, 40, -31, 58, -54, 165, -5, -19, 71,
                             87, -51, 24, 92, 15, 27, 62, -59, -82, -40, 91,
                             11, -16, -85, 138, -18, -36, 8, -25, -56, 75, -46,
                             -34, 67, 53, -4, -83, 111, -86, -29, -17, 45, -9,
                             38, -22, -3, -19, -17, -95, 94, 78, 63, -35, -51,
                             21, -63, -14, 87, 31, 44, -53, -107, 5}),
            });
}
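// Signed 4-bit variant: values in [-8, 7] are packed two per byte (low nibble
// masked with 0xF) and the result is accumulated into a QuantizedS16 output
// with scale 0.3f * 0.3f.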
TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    auto GenTensorValueQuint4 = [](const TensorShape& shape, dtype::QuantizedS4 dtype,
                                   const std::vector<int>& values) {
        TensorND tensor;
        tensor.layout = {shape, dtype};
        tensor.reset_ptr(
                static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte())));
        uint8_t* ptr = static_cast<uint8_t*>(tensor.raw_ptr());
        megdnn_assert(values.size() == tensor.layout.span().dist_elem());
        for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) {
            int val0 = values[i], val1 = values[i + 1];
            ptr[i / 2] = (val0 & 0xF) | (val1 << 4);
        }
        return tensor;
    };
    using Param = MatrixMul::Param;
    Param param;
    checker.set_param(param);
    checker.set_dtype(2, dtype::QuantizedS16(0.3f * 0.3f));
    checker.exect(
            Testcase{
                    GenTensorValueQuint4(
                            {8, 8}, dtype::QuantizedS4(0.3f),
                            {-8, 7, 2, 1, 2, 3, 2, 7, 2, 5, 3, 3, 7, 4, -7, 1,
                             -5, 7, -4, -1, -1, 2, 4, 1, 7, 2, -6, -2, -6, 3, 4, 4,
                             -2, 2, 3, 0, 6, 5, 3, 4, -1, -1, -5, 5, 2, 5, 1, 4,
                             6, 2, 0, 0, 3, 2, 2, 1, -4, -3, 7, 5, 0, 3, 2, 3}),
                    GenTensorValueQuint4(
                            {8, 8}, dtype::QuantizedS4(0.3f),
                            {5, -8, -7, -6, 4, 7, -5, -5, -4, 7, -3, -2, 5,
                             6, 4, 2, 3, -1, 2, 2, 7, 3, 6, 0, 5, 4,
                             0, 2, 2, 3, 3, 2, 1, -8, -7, -6, 0, -5, -4,
                             4, -3, 7, 1, 6, -2, 2, -1, 5, 2, 0, 7, 6,
                             5, 4, 3, 2, 0, 0, 1, 0, 5, 2, 2, 6}),
                    {}},
            Testcase{
                    {},
                    {},
                    TensorValue(
                            {8, 8}, dtype::QuantizedS16(0.3f * 0.3f),
                            {-60, 120, 49, 58, 58, 13, 92, 125, -5, 0, -116,
                             -70, 22, 9, -14, 46, -69, 111, 44, 48, 6, 19,
                             42, 57, -8, 25, 10, 16, 26, 97, -28, -12, -12,
                             14, 2, 26, 48, 7, 24, 93, -2, 45, 2, 32,
                             -19, -1, -16, 72, 23, -44, -52, -34, 45, 53, -28,
                             6, 33, 45, 71, 84, 47, 10, 74, 61})
            });
}
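// 8-bit asymmetric quantized matmul with explicit zero points. In the second
// (transposed-A) case both inputs hold {129, 129} with zero point 128, so each
// element dequantizes to 1 and the expected accumulator is 1 * 1 + 1 * 1 = 2.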
TEST_F(NAIVE, MATRIX_MUL_QUANTIZED8x8x32) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    MatrixMul::Param param;
    param.transposeA = false;
    param.transposeB = false;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {4, 7}, dtype::Quantized8Asymm(0.1f, (uint8_t)128),
                            {6, 97, 210, 47, 213, 246, 92, 121, 132, 133,
                             37, 31, 87, 71, 0, 5, 198, 11, 97, 141,
                             222, 166, 76, 212, 190, 108, 245, 143}),
                    TensorValue(
                            {7, 5}, dtype::Quantized8Asymm(0.2f, (uint8_t)233),
                            {89, 207, 79, 135, 43, 29, 235, 171, 40, 78, 119, 145,
                             254, 162, 184, 139, 248, 214, 201, 183, 127, 75, 48, 200,
                             96, 109, 63, 60, 100, 120, 111, 182, 150, 227, 92}),
                    {}},
            Testcase{
                    {},
                    {},
                    TensorValue(
                            {4, 5}, dtype::QuantizedS32(0.1f * 0.2f),
                            {2908, -36975, -9180, -3574, 8114, 30496, 23588,
                             32433, 11467, 30974, 36748, -6939, 26715, 33787,
                             35329, -24486, -25049, -19828, -16627, -18972})});

    param.transposeA = true;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {2, 1}, dtype::Quantized8Asymm(0.7f, (uint8_t)128),
                            {129, 129}),
                    TensorValue(
                            {2, 1}, dtype::Quantized8Asymm(0.4f, (uint8_t)128),
                            {129, 129}),
                    {}},
            Testcase{
                    {},
                    {},
                    TensorValue({1, 1}, dtype::QuantizedS32(0.7f * 0.4f), {2})});
}
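// Blocked-format coverage: MK4 and MK4_DOT pack 4 elements (MK8 packs 8) along
// the M and K dimensions; each case is verified against the relayout-based
// reference implementation in run_matmul_mk_format.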
TEST_F(NAIVE, MATRIX_MUL_MK4) {
    run_matmul_mk_format(
            handle(), param::MatrixMul::Format::MK4, dtype::Float32(), dtype::Float32(),
            dtype::Float32());
}

TEST_F(NAIVE, MATRIX_MUL_MK8) {
    run_matmul_mk_format(
            handle(), param::MatrixMul::Format::MK8, dtype::Int16(), dtype::Int16(),
            dtype::Int32());
}

TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) {
    run_matmul_mk_format(
            handle(), param::MatrixMul::Format::MK4_DOT, dtype::Int8(), dtype::Int8(),
            dtype::Int32());
}
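// BFloat16 matmul checked against a float32 reference built with
// extra_impl_helper, first with ComputeMode::FLOAT32 and then with the
// default compute mode.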
TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    MatrixMul::Param param, fp32_param;
    fp32_param = param;
    param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
    checker.set_param(param);
    checker.set_dtype(0, dtype::BFloat16());
    checker.set_dtype(1, dtype::BFloat16());
    checker.set_dtype(2, dtype::BFloat16());
    auto extra_impl = extra_impl_helper<MatrixMul>(handle(), fp32_param);
    checker.set_extra_opr_impl(extra_impl);
    checker.set_epsilon(1.5e-2);
    UniformFloatRNG frng{1e-2, 5.f};
    checker.set_rng(0, &frng);
    checker.set_rng(1, &frng);
    checker.execs({{8, 8}, {8, 8}, {}});

    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
    checker.set_param(param);
    checker.execs({{8, 8}, {8, 8}, {}});
}

// vim: syntax=cpp.doxygen