- /**
- * \file dnn/src/rocm/batched_matrix_mul/Blas.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
- #include "src/rocm/batched_matrix_mul/algos.h"
-
- #include "hcc_detail/hcc_defs_prologue.h"
- #include "src/rocm/handle.h"
- #include "src/rocm/utils.h"
-
- using namespace megdnn;
- using namespace rocm;
-
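- //! the BLAS algorithm only handles param::MatrixMul::Format::DEFAULT with
- //! float32 or float16 operands; any other format or dtype is rejected here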
- bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available(
- const SizeArgs& args) const {
- if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
- return false;
- if (args.layout_a.dtype == dtype::Float32() ||
- args.layout_a.dtype == dtype::Float16()) {
- return true;
- }
- return false;
- }
-
- void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const {
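- //! layouts are (batch, m, k) x (batch, k, n) -> (batch, m, n); when
- //! transposeA is set, A is laid out as (batch, k, m), so k comes from dim 1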
- auto batch = args.layout_a.shape[0];
- auto m = args.layout_c.shape[1], n = args.layout_c.shape[2];
- auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2];
- auto&& handle = concrete_handle(args.opr->handle());
- auto rocblas_handle_ = handle->get_rocblas_handle();
-
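- //! rocBLAS GEMMs are column-major while megdnn tensors are row-major, so
- //! C = op(A) * op(B) is computed as C^T = op(B)^T * op(A)^T: B and A are
- //! swapped, dims are passed as (n, m, k), stride[1] serves as the leading
- //! dimension and stride[0] as the per-matrix batch stride. alpha/beta are
- //! device-side constants cached on the handle, which is assumed to be in
- //! rocblas_pointer_mode_device.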
- auto sgemm = [&]() {
- auto zero = handle->zero_device();
- auto one = handle->one_device();
- rocblas_check(rocblas_sgemm_strided_batched(
- rocblas_handle_,
- args.opr->param().transposeB ? rocblas_operation_transpose
- : rocblas_operation_none,
- args.opr->param().transposeA ? rocblas_operation_transpose
- : rocblas_operation_none,
- n, m, k, one, args.tensor_b.ptr<dt_float32>(),
- (rocblas_int)(args.layout_b.stride[1]),
- (rocblas_int)(args.layout_b.stride[0]),
- args.tensor_a.ptr<dt_float32>(),
- (rocblas_int)(args.layout_a.stride[1]),
- (rocblas_int)(args.layout_a.stride[0]), zero,
- args.tensor_c.ptr<dt_float32>(),
- (rocblas_int)(args.layout_c.stride[1]),
- (rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch)));
- };
-
- #if !MEGDNN_DISABLE_FLOAT16
- //! used for ComputeMode::FLOAT32 (FLOAT_IO16xC32): float16 in/out with
- //! float32 accumulation; not yet covered by tests
- auto gemm_ex = [&]() {
- auto zero = handle->zero_device();
- auto one = handle->one_device();
- //! These two arguments are reserved for future use, see
- //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
- int32_t solution_index = 0;
- uint32_t flags = 1;
- size_t ws_size = 0;
-
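- //! A, B, C and the (in-place) D output are all rocblas_datatype_f16_r and
- //! the compute type is rocblas_datatype_f32_r, which is our reading of the
- //! FLOAT_IO16xC32 mode: float16 storage, float32 accumulation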
- rocblas_check(rocblas_gemm_strided_batched_ex(
- rocblas_handle_,
- args.opr->param().transposeB ? rocblas_operation_transpose
- : rocblas_operation_none,
- args.opr->param().transposeA ? rocblas_operation_transpose
- : rocblas_operation_none,
- n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r,
- args.layout_b.stride[1], args.layout_b.stride[0],
- args.tensor_a.raw_ptr, rocblas_datatype_f16_r,
- args.layout_a.stride[1], args.layout_a.stride[0], zero,
- args.tensor_c.raw_ptr, rocblas_datatype_f16_r,
- args.layout_c.stride[1], args.layout_c.stride[0],
- args.tensor_c.raw_ptr, rocblas_datatype_f16_r,
- args.layout_c.stride[1], args.layout_c.stride[0], batch,
- rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
- solution_index, flags, &ws_size, nullptr));
-
- MEGDNN_MARK_USED_VAR(ws_size);
- };
-
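- //! pure float16 path: operands, alpha and beta are all rocblas_half; the
- //! fp16 constants come from the handle (one_device_h / zero_device_h)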
- auto hgemm = [&]() {
- auto one_half = handle->one_device_h();
- auto zero_half = handle->zero_device_h();
- rocblas_check(rocblas_hgemm_strided_batched(
- rocblas_handle_,
- args.opr->param().transposeB ? rocblas_operation_transpose
- : rocblas_operation_none,
- args.opr->param().transposeA ? rocblas_operation_transpose
- : rocblas_operation_none,
- n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
- static_cast<const rocblas_half*>(args.tensor_b.raw_ptr),
- args.layout_b.stride[1], args.layout_b.stride[0],
- static_cast<const rocblas_half*>(args.tensor_a.raw_ptr),
- args.layout_a.stride[1], args.layout_a.stride[0],
- reinterpret_cast<const rocblas_half*>(zero_half),
- static_cast<rocblas_half*>(args.tensor_c.raw_ptr),
- args.layout_c.stride[1], args.layout_c.stride[0], batch));
- };
- #endif
-
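- //! dispatch: ComputeMode::DEFAULT runs the gemm matching the operand dtype
- //! (float32 -> sgemm, float16 -> hgemm); ComputeMode::FLOAT32 keeps float16
- //! I/O but accumulates in float32 via gemm_ex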
- if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) {
- if (args.layout_a.dtype == dtype::Float32()) {
- sgemm();
- }
- #if !MEGDNN_DISABLE_FLOAT16
- else {
- megdnn_assert(args.layout_a.dtype == dtype::Float16(),
- "invalid matmul data type");
- hgemm();
- }
- #endif
- }
- #if !MEGDNN_DISABLE_FLOAT16
- else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) {
- megdnn_assert(args.layout_a.dtype == dtype::Float16() &&
- args.layout_b.dtype == dtype::Float16() &&
- args.layout_c.dtype == dtype::Float16(),
- "ComputeMode::FLOAT32 (FLOAT_IO16xC32) requires the dtypes "
- "of A, B and C to all be Float16");
- gemm_ex();
- }
- #endif
- else {
- megdnn_throw("unsupported compute mode / data type for batched matrix mul on rocm");
- }
- }
-
- // vim: syntax=cpp.doxygen