You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batched_matrix_mul.cpp 1.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/rocm/fixture.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/rng.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(ROCM, BATCHED_MATRIX_MUL) {
  8. Checker<BatchedMatrixMul> checker(handle_rocm());
  9. checker.set_epsilon(1e-2);
  10. using Param = MatrixMul::Param;
  11. size_t b = 9, m = 10, n = 11, k = 12;
  12. std::vector<DType> dtypes{DNN_INC_FLOAT16(dtype::Float16() MEGDNN_COMMA)
  13. dtype::Float32()};
  14. for (auto dtype : dtypes)
  15. for (unsigned mask = 0; mask < 4; ++mask) {
  16. Param param;
  17. param.transposeA = mask & 1;
  18. param.transposeB = mask & 2;
  19. TensorShape A, B;
  20. if (param.transposeA)
  21. A = TensorShape{b, k, m};
  22. else
  23. A = TensorShape{b, m, k};
  24. if (param.transposeB)
  25. B = TensorShape{b, n, k};
  26. else
  27. B = TensorShape{b, k, n};
  28. checker.set_param(param)
  29. .set_dtype(0, dtype)
  30. .set_dtype(1, dtype)
  31. .set_dtype(2, dtype)
  32. .execs({A, B, {}});
  33. }
  34. }
  35. } // namespace test
  36. } // namespace megdnn
  37. // vim: syntax=cpp.doxygen