
blas.cpp 5.9 kB

/**
 * \file dnn/src/rocm/batched_matrix_mul/Blas.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/rocm/batched_matrix_mul/algos.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/handle.h"
#include "src/rocm/utils.h"

using namespace megdnn;
using namespace rocm;
bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available(
        const SizeArgs& args) const {
    if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
        return false;
    if (args.layout_a.dtype == dtype::Float32() ||
        args.layout_a.dtype == dtype::Float16()) {
        return true;
    }
    return false;
}
void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const {
    auto batch = args.layout_a.shape[0];
    auto m = args.layout_c.shape[1], n = args.layout_c.shape[2];
    auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2];
    auto&& handle = concrete_handle(args.opr->handle());
    auto rocblas_handle_ = handle->get_rocblas_handle();
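    //! rocBLAS expects column-major matrices while MegDNN tensors are
    //! row-major, so each lambda below passes B before A and swaps (m, n):
    //! computing C^T = B^T * A^T in column-major order yields C row-major.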
    auto sgemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        rocblas_check(rocblas_sgemm_strided_batched(
                rocblas_handle_,
                args.opr->param().transposeB ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                args.opr->param().transposeA ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                n, m, k, one, args.tensor_b.ptr<dt_float32>(),
                (rocblas_int)(args.layout_b.stride[1]),
                (rocblas_int)(args.layout_b.stride[0]),
                args.tensor_a.ptr<dt_float32>(),
                (rocblas_int)(args.layout_a.stride[1]),
                (rocblas_int)(args.layout_a.stride[0]), zero,
                args.tensor_c.ptr<dt_float32>(),
                (rocblas_int)(args.layout_c.stride[1]),
                (rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch)));
    };
#if !MEGDNN_DISABLE_FLOAT16
    //! used for FLOAT_IO16xC32, not tested
    auto gemm_ex = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        //! These two arguments are reserved for future use, see
        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
        int32_t solution_index = 0;
        uint32_t flags = 1;
        size_t ws_size = 0;
        rocblas_check(rocblas_gemm_strided_batched_ex(
                rocblas_handle_,
                args.opr->param().transposeB ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                args.opr->param().transposeA ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r,
                args.layout_b.stride[1], args.layout_b.stride[0],
                args.tensor_a.raw_ptr, rocblas_datatype_f16_r,
                args.layout_a.stride[1], args.layout_a.stride[0], zero,
                args.tensor_c.raw_ptr, rocblas_datatype_f16_r,
                args.layout_c.stride[1], args.layout_c.stride[0],
                args.tensor_c.raw_ptr, rocblas_datatype_f16_r,
                args.layout_c.stride[1], args.layout_c.stride[0], batch,
                rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                solution_index, flags, &ws_size, nullptr));
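        //! FLOAT_IO16xC32 means f16 storage for A/B/C with f32 accumulation,
        //! hence the f16_r I/O datatypes and the f32_r compute type above
        //! (the page originally showed i8_r/i32_r here, which contradicts
        //! the Float16 assertion below). The trailing (&ws_size, nullptr)
        //! pair appears to match the workspace-query parameters of the
        //! rocBLAS release this was written against; newer rocBLAS drops
        //! them from rocblas_gemm_strided_batched_ex.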
        MEGDNN_MARK_USED_VAR(ws_size);
    };
    auto hgemm = [&]() {
        auto one_half = handle->one_device_h();
        auto zero_half = handle->zero_device_h();
        rocblas_check(rocblas_hgemm_strided_batched(
                rocblas_handle_,
                args.opr->param().transposeB ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                args.opr->param().transposeA ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
                static_cast<const rocblas_half*>(args.tensor_b.raw_ptr),
                args.layout_b.stride[1], args.layout_b.stride[0],
                static_cast<const rocblas_half*>(args.tensor_a.raw_ptr),
                args.layout_a.stride[1], args.layout_a.stride[0],
                reinterpret_cast<const rocblas_half*>(zero_half),
                static_cast<rocblas_half*>(args.tensor_c.raw_ptr),
                args.layout_c.stride[1], args.layout_c.stride[0], batch));
    };
#endif
    if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) {
        if (args.layout_a.dtype == dtype::Float32()) {
            sgemm();
        }
#if !MEGDNN_DISABLE_FLOAT16
        else {
            megdnn_assert(args.layout_a.dtype == dtype::Float16(),
                          "invalid matmul data type");
            hgemm();
        }
#endif
    }
#if !MEGDNN_DISABLE_FLOAT16
    else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) {
        megdnn_assert(args.layout_b.dtype == dtype::Float16() &&
                              args.layout_c.dtype == dtype::Float16() &&
                              args.layout_a.dtype == dtype::Float16(),
                      "DataType::FLOAT_IO16xC32 is only supported when the "
                      "dtypes of A, B and C are all Float16");
        gemm_ex();
    }
#endif
    else {
        megdnn_throw("unsupported data type for matrix mul on rocm");
    }
}

// vim: syntax=cpp.doxygen
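
For reference, the call pattern above can be exercised outside MegEngine. The sketch below is hypothetical and not part of the repository; it assumes a working ROCm install with the rocBLAS headers and performs one row-major batched sgemm using the same B-before-A, (n, m)-swap trick as the sgemm lambda.

// batched_sgemm_demo.cpp -- hypothetical standalone demo, not MegEngine code.
// Build (assuming ROCm is installed): hipcc batched_sgemm_demo.cpp -lrocblas
#include <hip/hip_runtime.h>
#include <rocblas.h>
#include <cstdio>
#include <vector>

int main() {
    const int batch = 2, m = 4, n = 3, k = 5;
    std::vector<float> a(batch * m * k, 1.f);  // row-major (batch, m, k)
    std::vector<float> b(batch * k * n, 1.f);  // row-major (batch, k, n)
    std::vector<float> c(batch * m * n, 0.f);  // row-major (batch, m, n)

    float *da, *db, *dc;
    hipMalloc(&da, a.size() * sizeof(float));
    hipMalloc(&db, b.size() * sizeof(float));
    hipMalloc(&dc, c.size() * sizeof(float));
    hipMemcpy(da, a.data(), a.size() * sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(db, b.data(), b.size() * sizeof(float), hipMemcpyHostToDevice);

    rocblas_handle handle;
    rocblas_create_handle(&handle);
    const float one = 1.f, zero = 0.f;
    // Same trick as the sgemm lambda: B first and (n, m) swapped, so the
    // column-major product B^T * A^T lands in dc as row-major A * B.
    rocblas_status st = rocblas_sgemm_strided_batched(
            handle, rocblas_operation_none, rocblas_operation_none,
            n, m, k, &one,
            db, n, k * n,   // ldb = row stride of B
            da, k, m * k,   // lda = row stride of A
            &zero,
            dc, n, m * n,   // ldc = row stride of C
            batch);
    if (st != rocblas_status_success)
        std::printf("rocblas failed: %d\n", (int)st);

    hipMemcpy(c.data(), dc, c.size() * sizeof(float), hipMemcpyDeviceToHost);
    std::printf("c[0] = %g (expect %d)\n", c[0], k);  // all-ones inputs

    rocblas_destroy_handle(handle);
    hipFree(da);
    hipFree(db);
    hipFree(dc);
    return 0;
}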

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine actually has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, the MegStudio platform is available.
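
A typical install looks like the following (the find-links URL is the one MegEngine documented around this release; verify it against the current docs):

    python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html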