You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.h 4.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

#include "megdnn/dtype.h"
#include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
  9. namespace megdnn {
  10. namespace test {
  11. namespace matrix_mul {
  //! Parameters of a single matrix-multiplication test case.
  //!
  //! mask & 1 denotes transposeA; mask & 2 denotes transposeB.
  struct TestArg {
      //! Sentinel meaning "stride not specified". It is the default value of
      //! every stride argument of the constructor below.
      constexpr static size_t UNSET_STRIDE_VAL = static_cast<size_t>(-1);

      //! Problem sizes and the transpose mask described above.
      size_t m, n, k, mask;
      //! Strides for A, B and C, followed by b (defaults to 1; presumably the
      //! batch count -- TODO confirm against the users of this struct).
      size_t A_stride, B_stride, C_stride, b;
      //! Batch strides for A, B and C.
      size_t A_batch_stride, B_batch_stride, C_batch_stride;

      //! Stores every argument verbatim in the member of the same name.
      //!
      //! A stride left at UNSET_STRIDE_VAL is "not specified".
      //! NOTE(review): presumably the consumer then falls back to the stride
      //! that makes the tensor compact (contiguous). The previous comment
      //! said "stride = 0 means the default stride", which does not match the
      //! defaults declared here -- confirm whether 0 is treated specially by
      //! the implementation.
      TestArg(size_t m, size_t n, size_t k, size_t mask,
              size_t A_stride = UNSET_STRIDE_VAL,
              size_t B_stride = UNSET_STRIDE_VAL,
              size_t C_stride = UNSET_STRIDE_VAL, size_t b = 1,
              size_t A_batch_stride = UNSET_STRIDE_VAL,
              size_t B_batch_stride = UNSET_STRIDE_VAL,
              size_t C_batch_stride = UNSET_STRIDE_VAL)
              : m{m},
                n{n},
                k{k},
                mask{mask},
                A_stride{A_stride},
                B_stride{B_stride},
                C_stride{C_stride},
                b{b},
                A_batch_stride{A_batch_stride},
                B_batch_stride{B_batch_stride},
                C_batch_stride{C_batch_stride} {}
  };
  38. std::vector<TestArg> get_matmul_args_no_mask();
  39. std::vector<TestArg> get_matmul_args_mask(uint8_t mask);
  40. std::vector<TestArg> get_matmul_args();
  41. std::vector<TestArg> get_matmul_args_split_k();
  42. std::vector<TestArg> get_batched_matmul_args_mask(uint8_t mask);
  43. std::vector<TestArg> get_batched_matmul_args();
  44. std::vector<TestArg> get_batched_matmul_broadcast_args();
  45. std::vector<TestArg> get_batched_matmul_broadcast_args_mask(uint8_t mask);
  46. std::vector<TestArg> get_matmul_mk_packed_args(size_t nbase);
  47. std::vector<TestArg> get_batched_matmul_args_cublaslt();
  48. std::vector<TestArg> get_batched_matmul_args_int8x8x32();
  49. using TestArgFilterFunc = std::function<bool(const TestArg&)>;
  50. template <typename Opr = megdnn::MatrixMul>
  51. void check_matrix_mul(
  52. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  53. const ExecutionPolicyAlgoName& algo = {"", {}},
  54. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
  55. size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
  56. bool force_deduce_dst = true,
  57. param::MatrixMul::ComputeMode compute_mode =
  58. param::MatrixMul::ComputeMode::DEFAULT);
  59. void check_matrix_mul(
  60. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  61. const ExecutionPolicyAlgoName& algo = {"", {}},
  62. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
  63. size_t nbase = 8, float eps = 1e-3, bool force_deduce_dst = true);
  64. void check_batched_matrix_mul(
  65. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  66. const ExecutionPolicyAlgoName& algo = {"", {}}, float eps = 1e-3,
  67. std::vector<TestArg>&& args = {}, bool force_deduce_dst = true);
  // Benchmark helpers; only declared when MEGDNN_WITH_BENCHMARK is non-zero.
  #if MEGDNN_WITH_BENCHMARK
  //! Generators returning the TestArg lists for the benchmarks; `nbase` is
  //! forwarded to the MK-packed variant.
  std::vector<TestArg> get_benchmark_matmul_args();
  std::vector<TestArg> get_benchmark_matmul_mk_packed_args(size_t nbase);

  //! benchmark performance with float matmul
  //!
  //! Takes one configuration (dtypes, algo, format) and a second "contrast"
  //! configuration whose dtypes default to Float32, whose algo defaults to
  //! nullptr and whose format defaults to Format::DEFAULT.
  //! NOTE(review): a nullptr algo presumably means "no specific algorithm
  //! requested" -- confirm against the definition.
  void benchmark_with_contrast(
          Handle* handle, const std::vector<TestArg>& args, DType A_dtype,
          DType B_dtype, DType C_dtype, const char* algo = nullptr,
          param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
          DType contrast_A_dtype = dtype::Float32{},
          DType contrast_B_dtype = dtype::Float32{},
          DType contrast_C_dtype = dtype::Float32{},
          const char* contrast_algo = nullptr,
          param::MatrixMul::Format contrast_format =
                  param::MatrixMul::Format::DEFAULT);

  //! Declaration of the single-configuration benchmark: same leading
  //! parameters as benchmark_with_contrast, without the contrast settings.
  void benchmark_single_algo(
          Handle* handle, const std::vector<TestArg>& args, DType A_dtype,
          DType B_dtype, DType C_dtype, const char* algo = nullptr,
          param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT);
  #endif
  85. } // namespace matrix_mul
  86. } // namespace test
  87. } // namespace megdnn
// vim: syntax=cpp.doxygen