You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.h 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * \file dnn/test/common/matrix_mul.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>
#include "megdnn/dtype.h"
#include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
  19. namespace megdnn {
  20. namespace test {
  21. namespace matrix_mul {
  22. // mask & 1 denotes transposeA; mask & 2 denotes transposeB
  23. struct TestArg {
  24. constexpr static size_t UNSET_STRIDE_VAL = static_cast<size_t>(-1);
  25. size_t m, n, k, mask;
  26. size_t A_stride, B_stride, C_stride, b;
  27. size_t A_batch_stride, B_batch_stride, C_batch_stride;
  28. // stride = 0 means the default stride, the dim is contiguous, i.e. the
  29. // stride value which makes tensor compact.
  30. TestArg(size_t m, size_t n, size_t k, size_t mask,
  31. size_t A_stride = UNSET_STRIDE_VAL,
  32. size_t B_stride = UNSET_STRIDE_VAL,
  33. size_t C_stride = UNSET_STRIDE_VAL, size_t b = 1,
  34. size_t A_batch_stride = UNSET_STRIDE_VAL,
  35. size_t B_batch_stride = UNSET_STRIDE_VAL,
  36. size_t C_batch_stride = UNSET_STRIDE_VAL)
  37. : m{m},
  38. n{n},
  39. k{k},
  40. mask{mask},
  41. A_stride{A_stride},
  42. B_stride{B_stride},
  43. C_stride{C_stride},
  44. b{b},
  45. A_batch_stride{A_batch_stride},
  46. B_batch_stride{B_batch_stride},
  47. C_batch_stride{C_batch_stride} {}
  48. };
  49. std::vector<TestArg> get_matmul_args_no_mask();
  50. std::vector<TestArg> get_matmul_args_mask(uint8_t mask);
  51. std::vector<TestArg> get_matmul_args();
  52. std::vector<TestArg> get_matmul_args_split_k();
  53. std::vector<TestArg> get_batched_matmul_args_mask(uint8_t mask);
  54. std::vector<TestArg> get_batched_matmul_args();
  55. std::vector<TestArg> get_batched_matmul_broadcast_args();
  56. std::vector<TestArg> get_batched_matmul_broadcast_args_mask(uint8_t mask);
  57. std::vector<TestArg> get_matmul_mk_packed_args(size_t nbase);
  58. std::vector<TestArg> get_batched_matmul_args_cublaslt();
  59. std::vector<TestArg> get_batched_matmul_args_int8x8x32();
  60. using TestArgFilterFunc = std::function<bool(const TestArg&)>;
  61. template <typename Opr = megdnn::MatrixMul>
  62. void check_matrix_mul(
  63. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  64. const ExecutionPolicyAlgoName& algo = {"", {}},
  65. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
  66. size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
  67. bool force_deduce_dst = true);
  68. void check_matrix_mul(
  69. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  70. const ExecutionPolicyAlgoName& algo = {"", {}},
  71. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
  72. size_t nbase = 8, float eps = 1e-3, bool force_deduce_dst = true);
  73. void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
  74. Handle* handle,
  75. const ExecutionPolicyAlgoName& algo = {"", {}},
  76. float eps = 1e-3,
  77. std::vector<TestArg>&& args = {},
  78. bool force_deduce_dst = true);
  79. #if MEGDNN_WITH_BENCHMARK
  80. std::vector<TestArg> get_benchmark_matmul_args();
  81. std::vector<TestArg> get_benchmark_matmul_mk_packed_args(size_t nbase);
  82. //! benchmark performance with float matmul
  83. void benchmark_with_contrast(
  84. Handle* handle, const std::vector<TestArg>& args, DType A_dtype,
  85. DType B_dtype, DType C_dtype, const char* algo = nullptr,
  86. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
  87. DType contrast_A_dtype = dtype::Float32{},
  88. DType contrast_B_dtype = dtype::Float32{},
  89. DType contrast_C_dtype = dtype::Float32{},
  90. const char* contrast_algo = nullptr,
  91. param::MatrixMul::Format contrast_format =
  92. param::MatrixMul::Format::DEFAULT);
  93. #endif
  94. } // namespace matrix_mul
  95. } // namespace test
  96. } // namespace megdnn
  97. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台