You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.cpp 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/rocm/fixture.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/matrix_mul.h"
  5. #include "src/rocm/utils.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(ROCM, MATRIX_MUL) {
  9. Checker<MatrixMul> checker(handle_rocm());
  10. using Param = MatrixMul::Param;
  11. size_t m = 12, n = 16, k = 20;
  12. //! result error for Int8x8x32, not test correctness
  13. std::vector<DType> dtypes{DNN_INC_FLOAT16(dtype::Float16() MEGDNN_COMMA)
  14. dtype::Float32() /*, dtype::Int32()*/};
  15. for (auto dtype : dtypes) {
  16. for (unsigned mask = 0; mask < 4; ++mask) {
  17. Param param;
  18. param.transposeA = mask & 1;
  19. param.transposeB = mask & 2;
  20. DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
  21. TensorShape A, B;
  22. if (param.transposeA)
  23. A = TensorShape{k, m};
  24. else
  25. A = TensorShape{m, k};
  26. if (param.transposeB)
  27. B = TensorShape{n, k};
  28. else
  29. B = TensorShape{k, n};
  30. checker.set_param(param)
  31. .set_dtype(0, stype)
  32. .set_dtype(1, stype)
  33. .set_dtype(2, dtype)
  34. .set_epsilon(
  35. DNN_FLOAT16_SELECT(dtype == dtype::Float16(), false) ? 5e-2
  36. : 5e-3)
  37. .execs({A, B, {}});
  38. }
  39. }
  40. // general tests
  41. auto args = matrix_mul::get_matmul_args();
  42. for (auto arg : args) {
  43. auto m = arg.m, n = arg.n, k = arg.k;
  44. auto mask = arg.mask;
  45. Param param;
  46. param.transposeA = mask & 1;
  47. param.transposeB = mask & 2;
  48. TensorShape AS, BS, CS;
  49. if (param.transposeA)
  50. AS = TensorShape{k, m};
  51. else
  52. AS = TensorShape{m, k};
  53. if (param.transposeB)
  54. BS = TensorShape{n, k};
  55. else
  56. BS = TensorShape{k, n};
  57. CS = TensorShape{m, n};
  58. TensorLayout AL, BL, CL;
  59. if (arg.A_stride == 0) {
  60. AL = TensorLayout(AS, dtype::Float32());
  61. } else {
  62. AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1}, dtype::Float32());
  63. }
  64. if (arg.B_stride == 0) {
  65. BL = TensorLayout(BS, dtype::Float32());
  66. } else {
  67. BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1}, dtype::Float32());
  68. }
  69. if (arg.C_stride == 0) {
  70. CL = TensorLayout(CS, dtype::Float32());
  71. } else {
  72. CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1}, dtype::Float32());
  73. }
  74. checker.set_param(param).execl({AL, BL, CL});
  75. }
  76. }
  77. } // namespace test
  78. } // namespace megdnn
  79. // vim: syntax=cpp.doxygen