GitOrigin-RevId: dea03a0f7a
tags/v1.3.0
@@ -0,0 +1,107 @@
/**
 * \file dnn/src/fallback/batched_matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/fallback/batched_matrix_mul/algos.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/naive/handle.h" | |||
using namespace megdnn; | |||
using namespace fallback; | |||
BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&algo_default); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||
BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
BatchedMatrixMulForwardImpl* o, const TensorLayout& A, | |||
const TensorLayout& B, const TensorLayout& C) | |||
: opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||
BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( | |||
BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
_megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) | |||
: SizeArgs(opr, A.layout, B.layout, C.layout), | |||
tensor_a{A}, | |||
tensor_b{B}, | |||
tensor_c{C}, | |||
workspace{workspace} {} | |||
std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
auto&& param = opr->param(); | |||
size_t m = layout_a.shape[0], n = layout_b.shape[1], | |||
k = layout_a.shape[param.transposeA ? 0 : 1]; | |||
MEGDNN_MARK_USED_VAR(m); | |||
MEGDNN_MARK_USED_VAR(n); | |||
MEGDNN_MARK_USED_VAR(k); | |||
return megdnn_mangle(ssprintf( | |||
"A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||
"B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||
m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||
} | |||
/* ===================== default algo ===================== */ | |||
size_t BatchedMatrixMulForwardImpl::AlgoDefault::get_workspace_in_bytes( | |||
const SizeArgs& args) const { | |||
auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||
auto A_ = args.layout_a.remove_axis(0), B_ = args.layout_b.remove_axis(0), | |||
C_ = args.layout_c.remove_axis(0); | |||
opr->param() = args.opr->param(); | |||
return opr->get_workspace_in_bytes(A_, B_, C_); | |||
} | |||
void BatchedMatrixMulForwardImpl::AlgoDefault::exec( | |||
const ExecArgs& args) const { | |||
    //! Because megbrain may modify param when checking all transpose
    //! situations, we copy the param here before dispatching the kern.
    auto param = args.opr->param();
    auto kern = [args, param]() {
        auto N = args.layout_a.shape[0];
        TensorND A_, B_, C_;
        A_.raw_ptr = args.tensor_a.raw_ptr;
        A_.layout = args.layout_a.remove_axis(0);
        B_.raw_ptr = args.tensor_b.raw_ptr;
        B_.layout = args.layout_b.remove_axis(0);
        C_.raw_ptr = args.tensor_c.raw_ptr;
        C_.layout = args.layout_c.remove_axis(0);
        auto Astrd = args.layout_a.dtype.size() * args.layout_a.stride[0],
             Bstrd = args.layout_b.dtype.size() * args.layout_b.stride[0],
             Cstrd = args.layout_c.dtype.size() * args.layout_c.stride[0];
        auto advance_ptr = [](TensorND& dest, ptrdiff_t d) {
            dest.raw_ptr =
                    static_cast<void*>(static_cast<dt_byte*>(dest.raw_ptr) + d);
        };
        auto opr = inplace_cpu_handle()->create_operator<MatrixMul>();
        opr->param() = param;
        rep(n, N) {
            opr->exec(A_, B_, C_, args.workspace);
            advance_ptr(A_, Astrd);
            advance_ptr(B_, Bstrd);
            advance_ptr(C_, Cstrd);
        }
    };
    static_cast<naive::HandleImpl*>(args.opr->handle())->dispatch_kern(kern);
}

// vim: syntax=cpp.doxygen
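// Note (sketch, not part of the diff): AlgoDefault::exec above implements
// batched matmul by slicing off the batch axis with remove_axis(0) and then
// advancing the raw pointers by one batch stride per iteration. A minimal
// self-contained sketch of the same pattern, with a naive row-major GEMM
// standing in for the dispatched MatrixMul operator (all names hypothetical):

#include <cstddef>

namespace sketch {
// Naive C = A * B, row-major, no transpose; plays the role of opr->exec().
inline void run_gemm(const float* A, const float* B, float* C,
                     size_t M, size_t N, size_t K) {
    for (size_t i = 0; i < M; ++i)
        for (size_t j = 0; j < N; ++j) {
            float acc = 0.f;
            for (size_t k = 0; k < K; ++k)
                acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
}

inline void batched_gemm(const float* A, const float* B, float* C,
                         size_t batch, size_t M, size_t N, size_t K) {
    // Contiguous layouts assumed, so one batch stride is one whole matrix;
    // this mirrors advance_ptr(A_, Astrd) etc. in the code above.
    for (size_t b = 0; b < batch; ++b) {
        run_gemm(A + b * M * K, B + b * K * N, C + b * M * N, M, N, K);
    }
}
}  // namespace sketch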
@@ -0,0 +1,109 @@
/**
 * \file dnn/src/fallback/batched_matrix_mul/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/fallback/batched_matrix_mul/opr_impl.h"

#include <memory>
#include <unordered_map>

namespace megdnn {
namespace fallback {
/*!
 * \brief base class for batched matrix mul algos
 */
class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    enum class AlgoType : uint32_t {
        fallback_BLAS,
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; }
    struct SizeArgs {
        BatchedMatrixMulForwardImpl* opr;
        TensorLayout layout_a, layout_b, layout_c;
        std::string to_string() const;
        SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A,
                 const TensorLayout& B, const TensorLayout& C);
    };
    struct ExecArgs : public SizeArgs {
        TensorND tensor_a, tensor_b, tensor_c;
        Workspace workspace;
        ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A,
                 _megdnn_tensor_in B, _megdnn_tensor_out C,
                 _megdnn_workspace workspace);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs&) const = 0;

    bool is_available_wk(const SizeArgs& args, size_t limit) const {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }
    bool is_available_reproducible(
            const SizeArgs& args, bool reproducible = true,
            size_t limit = std::numeric_limits<size_t>::max()) const {
        return (!reproducible || is_reproducible()) &&
               is_available_wk(args, limit);
    }
    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
                name(), req, workspace.size);
        return *this;
    }
};

class BatchedMatrixMulForwardImpl::AlgoDefault final : public AlgoBase {
public:
    AlgoDefault() = default;
    bool is_available(const SizeArgs&) const override { return true; }
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override;
    const char* name() const override { return "DEFAULT"; }
    void exec(const ExecArgs&) const override;
    bool is_reproducible() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(fallback_BLAS)
};

class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();
    AlgoDefault algo_default;
    std::vector<AlgoBase*> all_algos;

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen
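// Note (sketch, not part of the diff): the AlgoBase interface above is
// consumed by filtering candidates on availability and workspace limit,
// optionally requiring reproducibility. A minimal sketch of how a caller
// might pick an algorithm; pick_algo and the `algos` vector are illustrative
// names, not MegDNN API:

#include <cstddef>
#include <vector>

template <typename AlgoBase, typename SizeArgs>
AlgoBase* pick_algo(const std::vector<AlgoBase*>& algos, const SizeArgs& args,
                    size_t workspace_limit, bool need_reproducible) {
    for (AlgoBase* algo : algos) {
        // is_available_reproducible() == is_available_wk() plus an optional
        // is_reproducible() requirement, exactly as defined above.
        if (algo->is_available_reproducible(args, need_reproducible,
                                            workspace_limit))
            return algo;
    }
    return nullptr;  // caller decides whether this is an error
}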
@@ -6,67 +6,61 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./opr_impl.h" | |||
#include "src/naive/handle.h" | |||
#include "./algos.h" | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "src/common/algo_chooser.h" | |||
#include "src/common/utils.cuh" | |||
#include "src/fallback/handle.h" | |||
using namespace megdnn; | |||
using namespace fallback; | |||
BatchedMatrixMulImpl::BatchedMatrixMulImpl(Handle *handle): | |||
BatchedMatrixMulForwardImpl(handle), | |||
m_storage(new CpuOprDelegationStorage<>), | |||
m_opr(m_storage->get<MatrixMul>()) | |||
{ | |||
std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) { | |||
AlgoBase::SizeArgs args{this, A, B, C}; | |||
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||
} | |||
size_t BatchedMatrixMulImpl::get_workspace_in_bytes( | |||
const TensorLayout &A, const TensorLayout &B, | |||
const TensorLayout &C) { | |||
auto A_ = A.remove_axis(0), B_ = B.remove_axis(0), C_ = C.remove_axis(0); | |||
m_opr->param() = param(); | |||
return m_opr->get_workspace_in_bytes(A_, B_, C_); | |||
BatchedMatrixMulForwardImpl::Algorithm* | |||
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, bool reproducible) { | |||
AlgoBase::SizeArgs args{this, A, B, C}; | |||
if (sm_algo_pack.algo_default.is_available_reproducible( | |||
args, reproducible, workspace_limit_in_bytes)) { | |||
return &sm_algo_pack.algo_default; | |||
} | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
"batched matrix mul forward"); | |||
} else { | |||
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
"batched matrix mul forward"); | |||
} | |||
} | |||
void BatchedMatrixMulImpl::exec(_megdnn_tensor_in A, | |||
_megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace) { | |||
check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
m_opr->param() = this->param(); | |||
auto kern = [this, A, B, C, workspace]() { | |||
auto N = A.layout.shape[0]; | |||
TensorND A_, B_, C_; | |||
A_.raw_ptr = A.raw_ptr; | |||
A_.layout = A.layout.remove_axis(0); | |||
B_.raw_ptr = B.raw_ptr; | |||
B_.layout = B.layout.remove_axis(0); | |||
C_.raw_ptr = C.raw_ptr; | |||
C_.layout = C.layout.remove_axis(0); | |||
auto Astrd = A.layout.dtype.size() * A.layout.stride[0], | |||
Bstrd = B.layout.dtype.size() * B.layout.stride[0], | |||
Cstrd = C.layout.dtype.size() * C.layout.stride[0]; | |||
auto advance_ptr = [](TensorND &dest, ptrdiff_t d) { | |||
dest.raw_ptr = static_cast<void*>( | |||
static_cast<dt_byte*>(dest.raw_ptr) + d); | |||
}; | |||
rep(n, N) { | |||
m_opr->exec(A_, B_, C_, workspace); | |||
advance_ptr(A_, Astrd); | |||
advance_ptr(B_, Bstrd); | |||
advance_ptr(C_, Cstrd); | |||
} | |||
}; | |||
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); | |||
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||
AlgoBase::SizeArgs args{this, A, B, C}; | |||
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||
} | |||
void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace) { | |||
check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||
algo->check_workspace(args, workspace).exec(args); | |||
} | |||
// vim: syntax=cpp.doxygen | |||
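// Note (sketch, not part of the diff): both exec() variants above build a
// lambda and hand it to dispatch_kern() instead of computing inline, and the
// operator param is captured *by value* because it may be mutated (see the
// transpose-probing comment in algos.cpp) before the kern actually runs. A
// tiny self-contained sketch of why by-value capture matters with deferred
// execution (Param and the queue are hypothetical stand-ins):

#include <functional>
#include <vector>

struct Param { bool transposeA = false; };

int main() {
    Param param;                               // mutable operator state
    std::vector<std::function<void()>> queue;  // stands in for dispatch_kern
    Param snapshot = param;  // copy, like `auto param = args.opr->param();`
    queue.push_back([snapshot] { (void)snapshot; /* sees transposeA == false */ });
    param.transposeA = true;  // later mutation does not affect the queued kern
    for (auto& kern : queue) kern();
    return 0;
}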
@@ -15,26 +15,42 @@
namespace megdnn {
namespace fallback {

class BatchedMatrixMulImpl: public naive::BatchedMatrixMulForwardImpl {
public:
    BatchedMatrixMulImpl(Handle *handle);
    void exec(
            _megdnn_tensor_in A,
            _megdnn_tensor_in B,
            _megdnn_tensor_out C,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout &A,
                                  const TensorLayout &B,
                                  const TensorLayout &C) override;

private:
    std::unique_ptr<CpuOprDelegationStorage<>> m_storage;
    MatrixMulForward* m_opr;
class BatchedMatrixMulForwardImpl: public naive::BatchedMatrixMulForwardImpl {
public:
    using naive::BatchedMatrixMulForwardImpl::BatchedMatrixMulForwardImpl;
    void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
              _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override;
    bool is_thread_safe() const override { return true; }

    class AlgoBase;
    class AlgoDefault;
    class AlgoPack;

    static const AlgoPack& algo_pack() { return sm_algo_pack; }
    static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

private:
    std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& /*A*/, const TensorLayout& /*B*/,
            const TensorLayout& /*C*/) override;
    Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
                                       const TensorLayout& /*B*/,
                                       const TensorLayout& /*C*/,
                                       size_t /*workspace_limit_in_bytes*/,
                                       bool /*reproducible*/) override;

    const char* get_algorithm_set_name() const override {
        return "FALLBACK BATCHED MATMUL";
    }

    static AlgoPack sm_algo_pack;
};

}  // namespace fallback
}  // namespace megdnn
}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen
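// Note (sketch, not part of the diff): sm_algo_pack above is a *static*
// member, so every operator instance shares one table of stateless algorithm
// objects, and AlgorithmDesc -> AlgoBase* lookup works across instances. The
// minimal shape of that pattern, with hypothetical names:

#include <vector>

struct Algo { /* stateless algorithm object */ };

class Opr {
public:
    static const std::vector<Algo*>& all_algos() { return pack().algos; }

private:
    struct Pack {
        std::vector<Algo*> algos;  // filled once, then read-only
    };
    static Pack& pack() {
        static Pack p;  // one pack shared by all Opr instances
        return p;
    }
};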
@@ -474,6 +474,13 @@ public:
                        "NoPackStrategyType::FLOAT16_FLOAT16"_hash);
                break;
#endif
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, NO_PACK, dt_float16, __fp16,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::FLOAT_FP16"_hash);
                break;
#endif
            case StrategyType::INT8x8x16:
                cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::ADD_BIAS,

@@ -169,6 +169,10 @@ INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif
#undef INSTANTIAL_CLASS
}  // namespace megdnn
@@ -67,7 +67,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdate)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMul)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC)
@@ -10,13 +10,18 @@
 */
#include "src/fallback/matrix_mul/algos.h"
#include "megdnn/opr_param_defs.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/gemv.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive)

using namespace megdnn;
using namespace fallback;

@@ -39,6 +44,32 @@ void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) {
    }
    MIDOUT_END();
}

void kern_naive(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) {
        size_t M = kern_param.M, N = kern_param.N, K = kern_param.K;
        size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
#define DISPATCH(TA, TB)                                                  \
    if (kern_param.trA == TA && kern_param.trB == TB) {                   \
        naive::dispatch_ta_tb<TA, TB>(                                    \
                kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr,     \
                kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC,         \
                kern_param.A_type, kern_param.B_type, kern_param.C_type,  \
                kern_param.format, kern_param.compute_mode);              \
        return;                                                           \
    }
        DISPATCH(true, true);
        DISPATCH(true, false);
        DISPATCH(false, true);
        DISPATCH(false, false);
#undef DISPATCH
        megdnn_assert_internal(0);
    }
    MIDOUT_END();
}

}  // anonymous namespace

////////////////////// AlgoF32K8x12x1 ///////////////////////////

@@ -84,11 +115,14 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
bool MatrixMulImpl::AlgoGemv::usable(
        const KernSizeParam& kern_size_param) const {
    return !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           !((kern_size_param.A_type.enumv() ==
              kern_size_param.B_type.enumv()) &&
             (kern_size_param.A_type.enumv() == DTypeEnum::Int16) &&
             (kern_size_param.C_type.enumv() == DTypeEnum::Int32));
           kern_size_param.format ==
                   param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode ==
                   param::MatrixMul::ComputeMode::DEFAULT &&
           !((kern_size_param.A_type.enumv() ==
              kern_size_param.B_type.enumv()) &&
             (kern_size_param.A_type.enumv() == DTypeEnum::Int16) &&
             (kern_size_param.C_type.enumv() == DTypeEnum::Int32));
}

bool MatrixMulImpl::AlgoGemv::preferred(

@@ -128,4 +162,44 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern(
    megdnn_assert(0);
}

/* ===================== naive algo ===================== */
bool MatrixMulImpl::AlgoNaive::usable(const KernSizeParam&) const {
    return true;
}

bool MatrixMulImpl::AlgoNaive::preferred(const KernSizeParam&) const {
    return false;
}

size_t MatrixMulImpl::AlgoNaive::get_workspace(
        const KernSizeParam& kern_param) const {
    MIDOUT_BEGIN(
            megdnn_fb_matmul_naive,
            midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
        if (kern_param.A_type.enumv() == DTypeEnum::Quantized4Asymm ||
            kern_param.A_type.enumv() == DTypeEnum::QuantizedS4) {
            size_t ret = 0;
            if (kern_param.trA) {
                ret += kern_param.LDA * kern_param.K;
            } else {
                ret += kern_param.LDA * kern_param.M;
            }
            if (kern_param.trB) {
                ret += kern_param.LDB * kern_param.N;
            } else {
                ret += kern_param.LDB * kern_param.K;
            }
            return ret;
        }
        return 0;
    }
    MIDOUT_END();
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(
        const KernSizeParam&) const {
    return kern_naive;
}

// vim: syntax=cpp.doxygen
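// Note (sketch, not part of the diff): kern_naive above maps the runtime
// (trA, trB) pair onto compile-time template parameters via the DISPATCH
// macro, so each transpose combination gets its own specialized inner loop.
// The same technique without a macro, on a hypothetical float kernel:

#include <cstddef>

template <bool TA, bool TB>
void matmul_kernel(const float* A, const float* B, float* C,
                   size_t M, size_t N, size_t K) {
    for (size_t i = 0; i < M; ++i)
        for (size_t j = 0; j < N; ++j) {
            float acc = 0.f;
            for (size_t k = 0; k < K; ++k) {
                // TA/TB are template constants: the branch is resolved at
                // compile time in each of the four instantiations.
                float a = TA ? A[k * M + i] : A[i * K + k];
                float b = TB ? B[j * K + k] : B[k * N + j];
                acc += a * b;
            }
            C[i * N + j] = acc;
        }
}

void matmul(const float* A, const float* B, float* C, size_t M, size_t N,
            size_t K, bool trA, bool trB) {
    if (trA && trB) return matmul_kernel<true, true>(A, B, C, M, N, K);
    if (trA)        return matmul_kernel<true, false>(A, B, C, M, N, K);
    if (trB)        return matmul_kernel<false, true>(A, B, C, M, N, K);
    return matmul_kernel<false, false>(A, B, C, M, N, K);
}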
@@ -52,6 +52,28 @@ public:
                                 DEFAULT)
};

class MatrixMulImpl::AlgoNaive final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "FB_NAIVE"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMM; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_DECL_ALGO_TYPE(FB_NAIVE)
    MEGDNN_OVERRIDE_MATMUL_DESC(
            8, 16, 1, 4,
            static_cast<AlgoDataType>(
                    static_cast<uint32_t>(AlgoDataType::FLOAT16) |
                    static_cast<uint32_t>(AlgoDataType::FLOAT32) |
                    static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
                    static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
                    static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)),
            DEFAULT)
};

}  // namespace fallback
}  // namespace megdnn
@@ -35,6 +35,7 @@ using namespace fallback;
class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoF32K8x12x1 f32_k8x12x1;
    AlgoGemv gemv;
    AlgoNaive naive;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

@@ -42,6 +43,7 @@ public:
    AlgoPack() {
        m_all_algos.emplace_back(&gemv);
        m_all_algos.emplace_back(&f32_k8x12x1);
        m_all_algos.emplace_back(&naive);
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
@@ -147,19 +149,26 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
    algo_type.format = kern_size_param.format;
    auto algos = select_algo_type(algo_type);
    Algorithm* heuristic_algo = nullptr;
    Algorithm* usable_algo = nullptr;
    for (auto&& algo : algos) {
        if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
            static_cast<AlgoBase*>(algo)->preferred_reproducible(
                    kern_size_param, reproducible) &&
            static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
                    workspace_limit_in_bytes) {
            if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
                return algo;
            } else if (!heuristic_algo) {
                heuristic_algo = algo;
            if (static_cast<AlgoBase*>(algo)->preferred_reproducible(
                        kern_size_param, reproducible)) {
                //! use gemv algo if it's preferred
                if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
                    return algo;
                } else if (!heuristic_algo) {
                    heuristic_algo = algo;
                }
            } else if (!usable_algo) {
                usable_algo = algo;
            }
        }
    }

    if (!heuristic_algo) heuristic_algo = usable_algo;
    megdnn_assert(heuristic_algo, "No usable algorithm found");
    return heuristic_algo;
}
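// Note (sketch, not part of the diff): the reworked heuristic above tracks
// two candidates -- the first *preferred* algo (returned eagerly when it is
// GEMV) and the first merely *usable* one, consulted only when nothing
// preferred exists. Its control flow, distilled to a standalone sketch with
// a hypothetical Algo type:

#include <cstddef>

struct Algo {
    bool usable_flag;
    bool preferred_flag;
    bool usable() const { return usable_flag; }
    bool preferred() const { return preferred_flag; }
};

Algo* choose(Algo* const* algos, size_t n) {
    Algo* heuristic = nullptr;  // best: usable and preferred
    Algo* fallback = nullptr;   // backup: usable only
    for (size_t i = 0; i < n; ++i) {
        Algo* a = algos[i];
        if (!a->usable()) continue;
        if (a->preferred()) {
            if (!heuristic) heuristic = a;
        } else if (!fallback) {
            fallback = a;
        }
    }
    return heuristic ? heuristic : fallback;  // may be nullptr: caller asserts
}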
@@ -110,6 +110,7 @@ public:
        //! fallback
        FB_F32K8x12x1 = 1 << 0,
        FB_GEMV,
        FB_NAIVE,

#if MEGDNN_X86
        //! x86

@@ -233,6 +234,7 @@ public:
private:
    class AlgoF32K8x12x1;  // Fallback F32 Kernel 8x12x1
    class AlgoGemv;
    class AlgoNaive;
    class AlgoPack;
    //! maintain all the algos in the fallback opr
    static const AlgoPack& algo_pack();
@@ -141,20 +141,39 @@ void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M,
}

template <bool transA, bool transB>
void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
                                        _megdnn_tensor_in B,
                                        _megdnn_tensor_out C,
                                        _megdnn_workspace workspace,
                                        const param::MatrixMul& param) {
void exec_matrix_mul_quint4x4x32_helper(
        const void* A, const void* B, void* C, void* workspace, size_t M,
        size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC,
        DType A_type, DType B_type, DType C_type,
        const MatrixMul::Param::Format& format,
        const MatrixMul::Param::ComputeMode& compute_mode) {
    MEGDNN_MARK_USED_VAR(C_type);
    MEGDNN_MARK_USED_VAR(format);
    MEGDNN_MARK_USED_VAR(compute_mode);
    auto convert_layout = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::Quantized4Asymm>();
        ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point);
        return ret;
    };
    TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
    TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
                   convert_layout(B.layout)};
    TensorLayout A_layout, B_layout;
    if (transA) {
        A_layout = TensorLayout({K, M}, {LDA, 1}, A_type);
    } else {
        A_layout = TensorLayout({M, K}, {LDA, 1}, A_type);
    }
    if (transB) {
        B_layout = TensorLayout({N, K}, {LDB, 1}, B_type);
    } else {
        B_layout = TensorLayout({K, N}, {LDB, 1}, B_type);
    }
    TensorND tensorA{const_cast<void*>(A), A_layout};
    TensorND tensorB{const_cast<void*>(B), B_layout};
    TensorND nA = {workspace, convert_layout(A_layout)};
    TensorND nB = {
            static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(),
            convert_layout(B_layout)};
    auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
        auto ptr =
                static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte;

@@ -168,31 +187,48 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
            out_ptr[i + 1] = val1;
        }
    };
    convert_4to8(A, nA);
    convert_4to8(B, nB);
    auto M = C.layout.shape[0], N = C.layout.shape[1];
    auto K = A.layout.shape[param.transposeA ? 0 : 1];
    auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
         LDC = C.layout.stride[0];
    convert_4to8(tensorA, nA);
    convert_4to8(tensorB, nB);
    run_matrix_mul_tpl<uint8_t, dt_int32, transA, transB, dt_int32>(
            nA.compatible_ptr<uint8_t>(), nB.compatible_ptr<uint8_t>(),
            C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC,
            nA.layout.dtype, nB.layout.dtype);
            static_cast<dt_int32*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype,
            nB.layout.dtype);
}

template <bool transA, bool transB>
void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
                                       _megdnn_tensor_out C,
                                       _megdnn_workspace workspace,
                                       const param::MatrixMul& param) {
void exec_matrix_mul_qint4x4x16_helper(
        const void* A, const void* B, void* C, void* workspace, size_t M,
        size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC,
        DType A_type, DType B_type, DType C_type,
        const MatrixMul::Param::Format& format,
        const MatrixMul::Param::ComputeMode& compute_mode) {
    MEGDNN_MARK_USED_VAR(C_type);
    MEGDNN_MARK_USED_VAR(format);
    MEGDNN_MARK_USED_VAR(compute_mode);
    auto convert_layout = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::QuantizedS4>();
        ret.dtype = dtype::QuantizedS8(param.scale);
        return ret;
    };
    TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
    TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
                   convert_layout(B.layout)};
    TensorLayout A_layout, B_layout;
    if (transA) {
        A_layout = TensorLayout({K, M}, {LDA, 1}, A_type);
    } else {
        A_layout = TensorLayout({M, K}, {LDA, 1}, A_type);
    }
    if (transB) {
        B_layout = TensorLayout({N, K}, {LDB, 1}, B_type);
    } else {
        B_layout = TensorLayout({K, N}, {LDB, 1}, B_type);
    }
    TensorND tensorA{const_cast<void*>(A), A_layout};
    TensorND tensorB{const_cast<void*>(B), B_layout};
    TensorND nA = {workspace, convert_layout(A_layout)};
    TensorND nB = {
            static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(),
            convert_layout(B_layout)};
    auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
        auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
        auto out_ptr =

@@ -204,18 +240,98 @@ void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
            out_ptr[i + 1] = cur >> 4;
        }
    };
    convert_4to8(A, nA);
    convert_4to8(B, nB);
    auto M = C.layout.shape[0], N = C.layout.shape[1];
    auto K = A.layout.shape[param.transposeA ? 0 : 1];
    auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
         LDC = C.layout.stride[0];
    convert_4to8(tensorA, nA);
    convert_4to8(tensorB, nB);
    run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>(
            nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(),
            C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC,
            nA.layout.dtype, nB.layout.dtype);
            static_cast<dt_int16*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype,
            nB.layout.dtype);
}
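// Note (sketch, not part of the diff): both 4-bit helpers above widen the
// packed nibbles into an 8-bit workspace buffer, then reuse the ordinary
// 8-bit kernel. A self-contained sketch of the signed-int4 unpacking step
// (hypothetical helper; `n` assumed even, two nibbles per input byte,
// low nibble first), using arithmetic shifts for sign extension:

#include <cstddef>
#include <cstdint>

inline void unpack_qint4_to_qint8(const int8_t* in, int8_t* out, size_t n) {
    for (size_t i = 0; i < n; i += 2) {
        int8_t cur = in[i / 2];
        // Shift the low nibble into the high bits, then arithmetic-shift it
        // back down so its sign bit is extended across the byte.
        out[i] = static_cast<int8_t>(static_cast<uint8_t>(cur) << 4) >> 4;
        // The high nibble already sits in the sign-carrying bits.
        out[i + 1] = cur >> 4;
    }
}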
template <bool TA, bool TB>
void dispatch_ta_tb(const void* A, const void* B, void* C, void* workspace,
                    size_t M, size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB,
                    ptrdiff_t LDC, DType A_type, DType B_type, DType C_type,
                    const MatrixMul::Param::Format& format,
                    const MatrixMul::Param::ComputeMode& compute_mode) {
#define cb(_itype, _otype, _comp_type)                                         \
    if (format == param::MatrixMul::Format::DEFAULT) {                         \
        return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>(         \
                static_cast<const _itype*>(A), static_cast<const _itype*>(B),  \
                static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type,       \
                B_type);                                                       \
    } else if (format == param::MatrixMul::Format::MK4) {                      \
        return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>(     \
                static_cast<const _itype*>(A), static_cast<const _itype*>(B),  \
                static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type,       \
                B_type);                                                       \
    } else if (format == param::MatrixMul::Format::MK4_DOT) {                  \
        return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
                static_cast<const _itype*>(A), static_cast<const _itype*>(B),  \
                static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type,       \
                B_type);                                                       \
    } else if (format == param::MatrixMul::Format::MK8) {                      \
        return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>(     \
                static_cast<const _itype*>(A), static_cast<const _itype*>(B),  \
                static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type,       \
                B_type);                                                       \
    }
    if (A_type == dtype::Float32()) {
        cb(dt_float32, dt_float32, dt_float32);
#if !MEGDNN_DISABLE_FLOAT16
    } else if (A_type == dtype::Float16()) {
        using Param = MatrixMul::Param;
        if (compute_mode == Param::ComputeMode::DEFAULT) {
            cb(dt_float16, dt_float16, dt_float16);
        } else if (compute_mode == Param::ComputeMode::FLOAT32) {
            cb(dt_float16, dt_float16, dt_float32);
        }
    } else if (A_type == dtype::BFloat16()) {
        using Param = MatrixMul::Param;
        if (compute_mode == Param::ComputeMode::DEFAULT) {
            cb(dt_bfloat16, dt_bfloat16, dt_bfloat16);
        } else if (compute_mode == Param::ComputeMode::FLOAT32) {
            cb(dt_bfloat16, dt_bfloat16, dt_float32);
        }
#endif
    } else if (A_type == dtype::Int8() && C_type == dtype::Int16()) {
        cb(dt_int8, dt_int16, dt_int16);
    } else if (A_type == dtype::Int16() && C_type == dtype::Int32()) {
        cb(dt_int16, dt_int32, dt_int32);
    } else if ((A_type == dtype::Int8() ||
                A_type.enumv() == DTypeEnum::QuantizedS8) &&
               (C_type == dtype::Int32() ||
                C_type.enumv() == DTypeEnum::QuantizedS32)) {
        cb(dt_int8, dt_int32, dt_int32);
    } else if (A_type.enumv() == DTypeEnum::Quantized8Asymm &&
               C_type.enumv() == DTypeEnum::QuantizedS32) {
        cb(uint8_t, dt_int32, dt_int32);
    } else if (A_type.enumv() == DTypeEnum::Quantized4Asymm &&
               C_type.enumv() == DTypeEnum::QuantizedS32 &&
               format == param::MatrixMul::Format::DEFAULT) {
        exec_matrix_mul_quint4x4x32_helper<TA, TB>(
                A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type,
                C_type, format, compute_mode);
        return;
    } else if (A_type.enumv() == DTypeEnum::QuantizedS4 &&
               C_type.enumv() == DTypeEnum::QuantizedS16 &&
               format == param::MatrixMul::Format::DEFAULT) {
        exec_matrix_mul_qint4x4x16_helper<TA, TB>(
                A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type,
                C_type, format, compute_mode);
        return;
    }
#undef cb
    megdnn_throw(
            ssprintf("unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)",
                     A_type.name(), B_type.name(), C_type.name(),
                     static_cast<int>(compute_mode)));
}

}  // namespace naive
}  // namespace megdnn
@@ -45,77 +45,10 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
    auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
         LDC = C.layout.stride[0];
#define cb(_itype, _otype, _comp_type)                                         \
    if (param.format == param::MatrixMul::Format::DEFAULT) {                   \
        return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>(         \
                A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(),        \
                C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC,            \
                A.layout.dtype, B.layout.dtype);                               \
    } else if (param.format == param::MatrixMul::Format::MK4) {                \
        return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>(     \
                A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(),        \
                C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC,            \
                A.layout.dtype, B.layout.dtype);                               \
    } else if (param.format == param::MatrixMul::Format::MK4_DOT) {            \
        return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
                A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(),        \
                C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC,            \
                A.layout.dtype, B.layout.dtype);                               \
    } else if (param.format == param::MatrixMul::Format::MK8) {                \
        return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>(     \
                A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(),        \
                C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC,            \
                A.layout.dtype, B.layout.dtype);                               \
    }
    if (A.layout.dtype == dtype::Float32()) {
        cb(dt_float32, dt_float32, dt_float32);
#if !MEGDNN_DISABLE_FLOAT16
    } else if (A.layout.dtype == dtype::Float16()) {
        using Param = MatrixMul::Param;
        if (param.compute_mode == Param::ComputeMode::DEFAULT) {
            cb(dt_float16, dt_float16, dt_float16);
        } else if (param.compute_mode == Param::ComputeMode::FLOAT32) {
            cb(dt_float16, dt_float16, dt_float32);
        }
    } else if (A.layout.dtype == dtype::BFloat16()) {
        using Param = MatrixMul::Param;
        if (param.compute_mode == Param::ComputeMode::DEFAULT) {
            cb(dt_bfloat16, dt_bfloat16, dt_bfloat16);
        } else if (param.compute_mode == Param::ComputeMode::FLOAT32) {
            cb(dt_bfloat16, dt_bfloat16, dt_float32);
        }
#endif
    } else if (A.layout.dtype == dtype::Int8() &&
               C.layout.dtype == dtype::Int16()) {
        cb(dt_int8, dt_int16, dt_int16);
    } else if (A.layout.dtype == dtype::Int16() &&
               C.layout.dtype == dtype::Int32()) {
        cb(dt_int16, dt_int32, dt_int32);
    } else if ((A.layout.dtype == dtype::Int8() ||
                A.layout.dtype.enumv() == DTypeEnum::QuantizedS8) &&
               (C.layout.dtype == dtype::Int32() ||
                C.layout.dtype.enumv() == DTypeEnum::QuantizedS32)) {
        cb(dt_int8, dt_int32, dt_int32);
    } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
               C.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
        cb(uint8_t, dt_int32, dt_int32);
    } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
               C.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
               param.format == param::MatrixMul::Format::DEFAULT) {
        exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param);
        return;
    } else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
               C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 &&
               param.format == param::MatrixMul::Format::DEFAULT) {
        exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param);
        return;
    }
#undef cb
    megdnn_throw(ssprintf(
            "unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)",
            A.layout.dtype.name(), B.layout.dtype.name(), C.layout.dtype.name(),
            static_cast<int>(param.compute_mode)));
    dispatch_ta_tb<TA, TB>(A.raw_ptr, B.raw_ptr, C.raw_ptr, workspace.raw_ptr,
                           M, N, K, LDA, LDB, LDC, A.layout.dtype,
                           B.layout.dtype, C.layout.dtype, param.format,
                           param.compute_mode);
}

void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A,
@@ -0,0 +1,59 @@
/**
 * \file dnn/src/rocm/batched_matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/rocm/batched_matrix_mul/algos.h" | |||
#include "src/common/algo_base.h" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&blas); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||
BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
BatchedMatrixMulForwardImpl* o, const TensorLayout& A, | |||
const TensorLayout& B, const TensorLayout& C) | |||
: opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||
BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( | |||
BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
_megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) | |||
: SizeArgs(opr, A.layout, B.layout, C.layout), | |||
tensor_a{A}, | |||
tensor_b{B}, | |||
tensor_c{C}, | |||
workspace{workspace} {} | |||
std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
auto&& param = opr->param(); | |||
size_t m = layout_a.shape[0], n = layout_b.shape[1], | |||
k = layout_a.shape[param.transposeA ? 0 : 1]; | |||
MEGDNN_MARK_USED_VAR(m); | |||
MEGDNN_MARK_USED_VAR(n); | |||
MEGDNN_MARK_USED_VAR(k); | |||
return megdnn_mangle(ssprintf( | |||
"A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||
"B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||
m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,118 @@
/**
 * \file dnn/src/rocm/batched_matrix_mul/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/rocm/batched_matrix_mul/opr_impl.h"

#include <memory>
#include <unordered_map>

namespace megdnn {
namespace rocm {
/*!
 * \brief base class for batched matrix mul algos
 */
class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    enum class AlgoType : uint32_t {
        ROCM_BLAS,
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
    struct SizeArgs {
        BatchedMatrixMulForwardImpl* opr;
        TensorLayout layout_a, layout_b, layout_c;
        std::string to_string() const;
        SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A,
                 const TensorLayout& B, const TensorLayout& C);

        bool can_be_treated_as_int8x8x32() const {
            return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
                   (layout_a.dtype.enumv() == DTypeEnum::Int8 ||
                    layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) &&
                   (layout_c.dtype.enumv() == DTypeEnum::Int32 ||
                    layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) &&
                   opr->param().format == param::MatrixMul::Format::DEFAULT;
        }
    };
    struct ExecArgs : public SizeArgs {
        TensorND tensor_a, tensor_b, tensor_c;
        Workspace workspace;
        ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A,
                 _megdnn_tensor_in B, _megdnn_tensor_out C,
                 _megdnn_workspace workspace);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    bool is_available_wk(const SizeArgs& args, size_t limit) const {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }
    bool is_available_reproducible(
            const SizeArgs& args, bool reproducible = true,
            size_t limit = std::numeric_limits<size_t>::max()) const {
        return (!reproducible || is_reproducible()) &&
               is_available_wk(args, limit);
    }
    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
                name(), req, workspace.size);
        return *this;
    }
};

class BatchedMatrixMulForwardImpl::AlgoBlas final : public AlgoBase {
public:
    AlgoBlas() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
        return 0_z;
    }
    const char* name() const override { return "BLAS"; }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};

class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();
    AlgoBlas blas;
    std::vector<AlgoBase*> all_algos;

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

}  // namespace rocm
}  // namespace megdnn

// vim: syntax=cpp.doxygen
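// Note (sketch, not part of the diff): can_be_treated_as_int8x8x32() above
// collapses the {Int8, QuantizedS8} x {Int32, QuantizedS32} dtype pairs into
// one kernel family, since quantized-s8 data is bit-identical to plain int8.
// The predicate's shape, reduced to a standalone sketch (this DType enum is
// hypothetical, not MegDNN's):

#include <cstdint>

enum class DType : uint8_t { Int8, QuantizedS8, Int32, QuantizedS32, Float32 };

inline bool is_int8_like(DType t) {
    return t == DType::Int8 || t == DType::QuantizedS8;
}
inline bool is_int32_like(DType t) {
    return t == DType::Int32 || t == DType::QuantizedS32;
}
// A and B must share a dtype, be int8-like, and accumulate into int32-like C.
inline bool can_be_treated_as_int8x8x32(DType a, DType b, DType c) {
    return a == b && is_int8_like(a) && is_int32_like(c);
}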
@@ -0,0 +1,140 @@
/**
 * \file dnn/src/rocm/batched_matrix_mul/Blas.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/rocm/batched_matrix_mul/algos.h" | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "src/rocm/handle.h" | |||
#include "src/rocm/utils.h" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available( | |||
const SizeArgs& args) const { | |||
if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) | |||
return false; | |||
if (args.layout_a.dtype == dtype::Float32() || | |||
args.layout_a.dtype == dtype::Float16()) { | |||
return true; | |||
} | |||
return false; | |||
} | |||
void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const {
    auto batch = args.layout_a.shape[0];
    auto m = args.layout_c.shape[1], n = args.layout_c.shape[2];
    auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2];
    auto&& handle = concrete_handle(args.opr->handle());
    auto rocblas_handle_ = handle->get_rocblas_handle();

    auto sgemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        rocblas_check(rocblas_sgemm_strided_batched(
                rocblas_handle_,
                args.opr->param().transposeB ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                args.opr->param().transposeA ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                n, m, k, one, args.tensor_b.ptr<dt_float32>(),
                (rocblas_int)(args.layout_b.stride[1]),
                (rocblas_int)(args.layout_b.stride[0]),
                args.tensor_a.ptr<dt_float32>(),
                (rocblas_int)(args.layout_a.stride[1]),
                (rocblas_int)(args.layout_a.stride[0]), zero,
                args.tensor_c.ptr<dt_float32>(),
                (rocblas_int)(args.layout_c.stride[1]),
                (rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch)));
    };
#if !MEGDNN_DISABLE_FLOAT16
    //! used for FLOAT_IO16xC32, not tested
    auto gemm_ex = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        //! These two arguments are reserved for future use; see
        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
        int32_t solution_index = 0;
        uint32_t flags = 1;
        size_t ws_size = 0;
        rocblas_check(rocblas_gemm_strided_batched_ex(
                rocblas_handle_,
                args.opr->param().transposeB ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                args.opr->param().transposeA ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r,
                args.layout_b.stride[1], args.layout_b.stride[0],
                args.tensor_a.raw_ptr, rocblas_datatype_i8_r,
                args.layout_a.stride[1], args.layout_a.stride[0], zero,
                args.tensor_c.raw_ptr, rocblas_datatype_i32_r,
                args.layout_c.stride[1], args.layout_c.stride[0],
                args.tensor_c.raw_ptr, rocblas_datatype_i32_r,
                args.layout_c.stride[1], args.layout_c.stride[0], batch,
                rocblas_datatype_i32_r, rocblas_gemm_algo_standard,
                solution_index, flags, &ws_size, nullptr));
        MEGDNN_MARK_USED_VAR(ws_size);
    };

    auto hgemm = [&]() {
        auto one_half = handle->one_device_h();
        auto zero_half = handle->zero_device_h();
        rocblas_check(rocblas_hgemm_strided_batched(
                rocblas_handle_,
                args.opr->param().transposeB ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                args.opr->param().transposeA ? rocblas_operation_transpose
                                             : rocblas_operation_none,
                n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
                static_cast<const rocblas_half*>(args.tensor_b.raw_ptr),
                args.layout_b.stride[1], args.layout_b.stride[0],
                static_cast<const rocblas_half*>(args.tensor_a.raw_ptr),
                args.layout_a.stride[1], args.layout_a.stride[0],
                reinterpret_cast<const rocblas_half*>(zero_half),
                static_cast<rocblas_half*>(args.tensor_c.raw_ptr),
                args.layout_c.stride[1], args.layout_c.stride[0], batch));
    };
#endif
    if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) {
        if (args.layout_a.dtype == dtype::Float32()) {
            sgemm();
        }
#if !MEGDNN_DISABLE_FLOAT16
        else {
            megdnn_assert(args.layout_a.dtype == dtype::Float16(),
                          "invalid matmul data type");
            hgemm();
        }
#endif
    }
#if !MEGDNN_DISABLE_FLOAT16
    else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) {
        megdnn_assert(args.layout_b.dtype == dtype::Float16() &&
                              args.layout_c.dtype == dtype::Float16() &&
                              args.layout_a.dtype == dtype::Float16(),
                      "DataType::FLOAT_IO16xC32 is only supported when the "
                      "dtypes of A, B and C are all Float16");
        gemm_ex();
    }
#endif
    else {
        megdnn_throw("Unsupported data_type of matrix mul on rocm.");
    }
}

// vim: syntax=cpp.doxygen
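// Note (sketch, not part of the diff): rocBLAS, like BLAS in general, is
// column-major, while MegDNN tensors are row-major. The calls above use the
// identity C^T = B^T * A^T: they pass B before A and swap m and n, so the
// column-major routine fills a row-major C directly, with no explicit
// transposition. A self-contained sketch with a toy column-major GEMM
// standing in for rocblas_sgemm_strided_batched:

#include <cstddef>

// Toy column-major C[m x n] = A[m x k] * B[k x n]; ld* are column strides.
void cm_gemm(size_t m, size_t n, size_t k, const float* A, size_t lda,
             const float* B, size_t ldb, float* C, size_t ldc) {
    for (size_t j = 0; j < n; ++j)
        for (size_t i = 0; i < m; ++i) {
            float acc = 0.f;
            for (size_t p = 0; p < k; ++p)
                acc += A[i + p * lda] * B[p + j * ldb];
            C[i + j * ldc] = acc;
        }
}

// Row-major GEMM built on the column-major one, mirroring the calls above:
// operands swapped, (m, n) exchanged, and leading dims become row strides.
void rm_gemm(size_t m, size_t n, size_t k, const float* A, size_t lda,
             const float* B, size_t ldb, float* C, size_t ldc) {
    cm_gemm(n, m, k, B, ldb, A, lda, C, ldc);
}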
@@ -10,111 +10,58 @@
 * implied.
 */
#include "./opr_impl.h"
#include "./algos.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.cuh"
#include "src/rocm/handle.h"
#include "src/rocm/utils.h"

namespace megdnn {
namespace rocm {
using namespace megdnn;
using namespace rocm;

std::vector<BatchedMatrixMulForwardImpl::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
                                                const TensorLayout& B,
                                                const TensorLayout& C) {
    AlgoBase::SizeArgs args{this, A, B, C};
    return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args);
}

BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
        size_t workspace_limit_in_bytes, bool reproducible) {
    AlgoBase::SizeArgs args{this, A, B, C};
    if (sm_algo_pack.blas.is_available_reproducible(args, reproducible,
                                                    workspace_limit_in_bytes)) {
        return &sm_algo_pack.blas;
    }
    if (reproducible) {
        return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>(
                sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
                "batched matrix mul forward");
    } else {
        return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
                sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
                "batched matrix mul forward");
    }
}

size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
    AlgoBase::SizeArgs args{this, A, B, C};
    return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
}

void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
                                       _megdnn_tensor_out C,
                                       _megdnn_workspace workspace) {
    check_exec(A.layout, B.layout, C.layout, workspace.size);
    auto dtype = A.layout.dtype;
    megdnn_assert(dtype.category() == DTypeCategory::FLOAT &&
                  param().format == param::MatrixMul::Format::DEFAULT);
    if (dtype == dtype::Float32() ||
        MEGDNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)) {
        auto batch = A.layout.shape[0];
        auto m = C.layout.shape[1], n = C.layout.shape[2];
        auto k = A.layout.shape[param().transposeA ? 1 : 2];
        auto handle = concrete_handle(this->handle());
        auto rocblas_handle_ = handle->get_rocblas_handle();
        auto io32_c32 = [&]() {
            auto zero = handle->zero_device();
            auto one = handle->one_device();
            rocblas_check(rocblas_sgemm_strided_batched(
                    rocblas_handle_,
                    param().transposeB ? rocblas_operation_transpose
                                       : rocblas_operation_none,
                    param().transposeA ? rocblas_operation_transpose
                                       : rocblas_operation_none,
                    n, m, k, one, B.ptr<dt_float32>(),
                    (rocblas_int)(B.layout.stride[1]),
                    (rocblas_int)(B.layout.stride[0]), A.ptr<dt_float32>(),
                    (rocblas_int)(A.layout.stride[1]),
                    (rocblas_int)(A.layout.stride[0]), zero,
                    C.ptr<dt_float32>(), (rocblas_int)(C.layout.stride[1]),
                    (rocblas_int)(C.layout.stride[0]), (rocblas_int)(batch)));
        };
#if !MEGDNN_DISABLE_FLOAT16
        auto io16_c32 = [&]() {
            auto zero = handle->zero_device();
            auto one = handle->one_device();
            int32_t solution_index = 0;
            uint32_t flags = 1;
            size_t ws_size = 0;
            rocblas_check(rocblas_gemm_strided_batched_ex(
                    rocblas_handle_,
                    param().transposeB ? rocblas_operation_transpose
                                       : rocblas_operation_none,
                    param().transposeA ? rocblas_operation_transpose
                                       : rocblas_operation_none,
                    n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r,
                    B.layout.stride[1], B.layout.stride[0], A.raw_ptr,
                    rocblas_datatype_i8_r, A.layout.stride[1],
                    A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r,
                    C.layout.stride[1], C.layout.stride[0], C.raw_ptr,
                    rocblas_datatype_i32_r, C.layout.stride[1],
                    C.layout.stride[0], batch, rocblas_datatype_i32_r,
                    rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
                    nullptr));
        };
        auto io16_c16 = [&]() {
            auto zero_half = handle->zero_device_h();
            auto one_half = handle->one_device_h();
            rocblas_check(rocblas_hgemm_strided_batched(
                    rocblas_handle_,
                    param().transposeB ? rocblas_operation_transpose
                                       : rocblas_operation_none,
                    param().transposeA ? rocblas_operation_transpose
                                       : rocblas_operation_none,
                    n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
                    static_cast<const rocblas_half*>(B.raw_ptr),
                    B.layout.stride[1], B.layout.stride[0],
                    static_cast<const rocblas_half*>(A.raw_ptr),
                    A.layout.stride[1], A.layout.stride[0],
                    reinterpret_cast<const rocblas_half*>(zero_half),
                    static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[1],
                    C.layout.stride[0], batch));
        };
#endif
        if (dtype == dtype::Float32()) {
            io32_c32();
        }
#if !MEGDNN_DISABLE_FLOAT16
        else {
            if (param().compute_mode == Param::ComputeMode::FLOAT32) {
                io16_c32();
            } else {
                io16_c16();
            }
        }
#endif
    }
    AlgoBase::ExecArgs args(this, A, B, C, workspace);
    auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
    algo->check_workspace(args, workspace).exec(args);
}

}  // namespace rocm
}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/oprs.h"

@@ -17,36 +18,35 @@ namespace rocm {
class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward {
public:
    using BatchedMatrixMulForward::BatchedMatrixMulForward;
    BatchedMatrixMulForwardImpl(Handle* handle)
            : BatchedMatrixMul(handle),
              m_opr(handle->create_operator<MatrixMul>()) {}
    void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
              _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
                                  const TensorLayout&) override;
    bool is_thread_safe() const override { return true; }

    class AlgoBase;
    class AlgoBlas;
    class AlgoPack;

    static const AlgoPack& algo_pack() { return sm_algo_pack; }
    static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

private:
    std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& /*A*/, const TensorLayout& /*B*/,
            const TensorLayout& /*C*/) override {
        return {};
    }
            const TensorLayout& /*C*/) override;
    Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
                                       const TensorLayout& /*B*/,
                                       const TensorLayout& /*C*/,
                                       size_t /*workspace_limit_in_bytes*/,
                                       bool /* reproducible */) override {
        return nullptr;
    }
    const char* get_algorithm_set_name() const override { return "DEFAULT"; }
                                       bool /*reproducible*/) override;
    bool is_thread_safe() const override { return true; }
    const char* get_algorithm_set_name() const override {
        return "ROCM BATCHED MATMUL";
    }

private:
    std::unique_ptr<MatrixMul> m_opr;
    static AlgoPack sm_algo_pack;
};

}  // namespace rocm
@@ -0,0 +1,62 @@
/**
 * \file dnn/src/rocm/matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/rocm/matrix_mul/algos.h" | |||
#include "src/common/algo_base.h" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&blas); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) | |||
MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | |||
const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) | |||
: opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||
MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr, | |||
_megdnn_tensor_in A, | |||
_megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace) | |||
: SizeArgs(opr, A.layout, B.layout, C.layout), | |||
tensor_a{A}, | |||
tensor_b{B}, | |||
tensor_c{C}, | |||
workspace{workspace} {} | |||
std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
auto&& param = opr->param(); | |||
    //! m and n always come from C, which is (m, n) for any transpose mode | |||
    size_t m = layout_c.shape[0], n = layout_c.shape[1], | |||
           k = layout_a.shape[param.transposeA ? 0 : 1]; | |||
MEGDNN_MARK_USED_VAR(m); | |||
MEGDNN_MARK_USED_VAR(n); | |||
MEGDNN_MARK_USED_VAR(k); | |||
return megdnn_mangle(ssprintf( | |||
"A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||
"B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||
m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||
} | |||
// vim: syntax=cpp.doxygen |
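To make the shape bookkeeping concrete: C is always (m, n), while the k extent of A sits at axis 1 normally and at axis 0 under transposeA. A standalone illustration of the derivation used by to_string(), with the shapes invented for the example:

#include <cstddef>
#include <cstdio>

int main() {
    // A is (m, k) = (64, 32), B is (k, n) = (32, 48), C is (m, n) = (64, 48).
    bool transposeA = false;
    size_t A_shape[2] = {64, 32};
    size_t C_shape[2] = {64, 48};
    size_t m = C_shape[0], n = C_shape[1];
    size_t k = A_shape[transposeA ? 0 : 1];
    std::printf("A={%zux%zu},B={%zux%zu},C={%zux%zu}\n", m, k, k, n, m, n);
    // prints: A={64x32},B={32x48},C={64x48}
    return 0;
}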
@@ -0,0 +1,118 @@ | |||
/** | |||
* \file dnn/src/rocm/matrix_mul/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/rocm/matrix_mul/opr_impl.h" | |||
#include <memory> | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace rocm { | |||
/*! | |||
 * \brief base class for matrix mul algos | |||
 */ | |||
class MatrixMulForwardImpl::AlgoBase : public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
ROCM_BLAS, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||
struct SizeArgs { | |||
MatrixMulForwardImpl* opr; | |||
TensorLayout layout_a, layout_b, layout_c; | |||
std::string to_string() const; | |||
SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, | |||
const TensorLayout& B, const TensorLayout& C); | |||
bool can_be_treated_as_int8x8x32() const { | |||
return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | |||
(layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||
layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||
(layout_c.dtype.enumv() == DTypeEnum::Int32 || | |||
layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && | |||
opr->param().format == param::MatrixMul::Format::DEFAULT; | |||
} | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
TensorND tensor_a, tensor_b, tensor_c; | |||
Workspace workspace; | |||
ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
_megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs& args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) const { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert( | |||
req <= workspace.size, | |||
"matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
}; | |||
class MatrixMulForwardImpl::AlgoBlas final : public AlgoBase { | |||
public: | |||
AlgoBlas() = default; | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||
return 0_z; | |||
} | |||
const char* name() const override { return "BLAS"; } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||
}; | |||
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
private: | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
AlgoBlas blas; | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
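The AlgoPack above repeats the registration pattern used elsewhere in megdnn: algorithms are appended to all_algos (which fixes the heuristic search order) and indexed by their descriptor so that MEGDNN_DEF_GET_ALGO_FROM_DESC can map a serialized descriptor back to the object. A reduced sketch of that mechanism, with the megdnn descriptor type deliberately simplified to a plain string:

#include <string>
#include <unordered_map>
#include <vector>

struct Algo {
    virtual ~Algo() = default;
    virtual std::string desc() const = 0;  // stands in for info().desc
};
struct Blas final : Algo {
    std::string desc() const override { return "ROCM_BLAS"; }
};

struct AlgoPack {
    Blas blas;
    std::vector<Algo*> all_algos;                      // search order
    std::unordered_map<std::string, Algo*> algos_map;  // desc -> algo
    AlgoPack() {
        all_algos.push_back(&blas);
        for (auto&& algo : all_algos)
            algos_map.emplace(algo->desc(), algo);
    }
    Algo* get_algo_from_desc(const std::string& d) const {
        auto it = algos_map.find(d);
        return it == algos_map.end() ? nullptr : it->second;
    }
};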
@@ -0,0 +1,162 @@ | |||
/** | |||
* \file dnn/src/rocm/matrix_mul/Blas.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/rocm/matrix_mul/algos.h" | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "src/rocm/handle.h" | |||
#include "src/rocm/utils.h" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
bool MatrixMulForwardImpl::AlgoBlas::is_available( | |||
const SizeArgs& args) const { | |||
if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) | |||
return false; | |||
if (args.layout_a.dtype == dtype::Float32() || | |||
args.layout_a.dtype == dtype::Float16()) { | |||
return true; | |||
} else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||
args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) { | |||
auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1]; | |||
//! see | |||
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470 | |||
bool rocblas_int8x8x32_valid = true; | |||
rocblas_int8x8x32_valid &= (k % 4 == 0); | |||
rocblas_int8x8x32_valid &= (!args.opr->param().transposeB || | |||
args.layout_b.stride[0] % 4 == 0); | |||
rocblas_int8x8x32_valid &= (!args.opr->param().transposeA || | |||
args.layout_a.stride[0] % 4 == 0); | |||
return rocblas_int8x8x32_valid; | |||
} | |||
return false; | |||
} | |||
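The three int8 conditions above mirror the argument validation inside rocblas_gemm_ex: K must be a multiple of 4, and the leading dimension of any transposed operand must also be a multiple of 4. A small worked check, with the numbers invented for illustration:

#include <cstddef>
#include <cstdio>

int main() {
    // k = 30 fails the k % 4 == 0 requirement, so the algo reports
    // itself unavailable and the dispatcher must pick something else.
    size_t k = 30, lda = 32, ldb = 32;
    bool transposeA = false, transposeB = true;
    bool valid = (k % 4 == 0) && (!transposeB || ldb % 4 == 0) &&
                 (!transposeA || lda % 4 == 0);
    std::printf("k=%zu -> rocblas int8x8x32 %s\n", k,
                valid ? "available" : "unavailable");
    return 0;
}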
void MatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const { | |||
auto m = args.layout_c.shape[0], n = args.layout_c.shape[1]; | |||
auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1]; | |||
auto&& handle = concrete_handle(args.opr->handle()); | |||
auto rocblas_handle_ = handle->get_rocblas_handle(); | |||
auto sgemm = [&]() { | |||
auto zero = handle->zero_device(); | |||
auto one = handle->one_device(); | |||
rocblas_check(rocblas_sgemm( | |||
rocblas_handle_, | |||
args.opr->param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
args.opr->param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, one, args.tensor_b.ptr<dt_float32>(), | |||
args.layout_b.stride[0], args.tensor_a.ptr<dt_float32>(), | |||
args.layout_a.stride[0], zero, args.tensor_c.ptr<dt_float32>(), | |||
args.layout_c.stride[0])); | |||
}; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
//! used for FLOAT_IO16xC32, not tested | |||
auto gemm_ex = [&]() { | |||
auto zero = handle->zero_device(); | |||
auto one = handle->one_device(); | |||
        //! These two arguments are reserved for future use; see | |||
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||
int32_t solution_index = 0; | |||
uint32_t flags = 1; | |||
size_t ws_size = 0; | |||
auto gemm_ex_err = rocblas_gemm_ex( | |||
rocblas_handle_, | |||
args.opr->param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
args.opr->param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r, | |||
args.layout_b.stride[0], args.tensor_a.raw_ptr, | |||
rocblas_datatype_f16_r, args.layout_a.stride[0], zero, | |||
args.tensor_c.raw_ptr, rocblas_datatype_f16_r, | |||
args.layout_c.stride[0], args.tensor_c.raw_ptr, | |||
rocblas_datatype_f16_r, args.layout_c.stride[0], | |||
rocblas_datatype_f32_r, rocblas_gemm_algo_standard, | |||
solution_index, flags, &ws_size, nullptr); | |||
rocblas_check(gemm_ex_err); | |||
MEGDNN_MARK_USED_VAR(ws_size); | |||
}; | |||
auto hgemm = [&]() { | |||
auto one_half = handle->one_device_h(); | |||
auto zero_half = handle->zero_device_h(); | |||
auto hgemm_err = rocblas_hgemm( | |||
rocblas_handle_, | |||
args.opr->param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
args.opr->param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||
static_cast<const rocblas_half*>(args.tensor_b.raw_ptr), | |||
args.layout_b.stride[0], | |||
static_cast<const rocblas_half*>(args.tensor_a.raw_ptr), | |||
args.layout_a.stride[0], | |||
reinterpret_cast<const rocblas_half*>(zero_half), | |||
static_cast<rocblas_half*>(args.tensor_c.raw_ptr), | |||
args.layout_c.stride[0]); | |||
rocblas_check(hgemm_err); | |||
}; | |||
#endif | |||
if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) { | |||
if (args.layout_a.dtype == dtype::Float32()) { | |||
sgemm(); | |||
} | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
else { | |||
megdnn_assert(args.layout_a.dtype == dtype::Float16(), | |||
"invalid matmul data type"); | |||
hgemm(); | |||
} | |||
#endif | |||
} | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) { | |||
megdnn_assert(args.layout_b.dtype == dtype::Float16() && | |||
args.layout_c.dtype == dtype::Float16() && | |||
args.layout_a.dtype == dtype::Float16(), | |||
"DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||
"B, C are all Float16"); | |||
gemm_ex(); | |||
} | |||
#endif | |||
else { | |||
megdnn_assert(args.can_be_treated_as_int8x8x32()); | |||
int32_t solution_index = 0; | |||
uint32_t flags = 1; | |||
size_t ws_size = 0; | |||
auto zero = handle->zero_device_i32(); | |||
auto one = handle->one_device_i32(); | |||
rocblas_check(rocblas_gemm_ex( | |||
rocblas_handle_, | |||
args.opr->param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
args.opr->param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r, | |||
args.layout_b.stride[0], args.tensor_a.raw_ptr, | |||
rocblas_datatype_i8_r, args.layout_a.stride[0], zero, | |||
args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||
args.layout_c.stride[0], args.tensor_c.raw_ptr, | |||
rocblas_datatype_i32_r, args.layout_c.stride[0], | |||
rocblas_datatype_i32_r, rocblas_gemm_algo_standard, | |||
solution_index, flags, &ws_size, nullptr)); | |||
MEGDNN_MARK_USED_VAR(ws_size); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
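Worth spelling out why every call above passes B before A and swaps the m/n arguments: rocBLAS, like legacy BLAS, assumes column-major storage, while megdnn tensors are row-major. Computing the row-major product C = op(A)*op(B) is therefore phrased as the column-major product C^T = op(B)^T * op(A)^T, whose column-major result buffer coincides element-for-element with row-major C. A small CPU sketch verifying the identity (gemm_colmajor is a naive helper written only for this illustration):

#include <cassert>
#include <cstddef>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n).
static void gemm_colmajor(size_t m, size_t n, size_t k, const float* A,
                          size_t lda, const float* B, size_t ldb, float* C,
                          size_t ldc) {
    for (size_t j = 0; j < n; ++j)
        for (size_t i = 0; i < m; ++i) {
            float acc = 0.f;
            for (size_t p = 0; p < k; ++p)
                acc += A[i + p * lda] * B[p + j * ldb];
            C[i + j * ldc] = acc;
        }
}

int main() {
    const size_t m = 2, n = 3, k = 4;
    std::vector<float> A(m * k), B(k * n), C(m * n), ref(m * n);
    for (size_t i = 0; i < A.size(); ++i) A[i] = 1.f + float(i);
    for (size_t i = 0; i < B.size(); ++i) B[i] = 0.5f * float(i) - 3.f;
    // Row-major reference: ref[i][j] = sum_p A[i][p] * B[p][j].
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j) {
            float acc = 0.f;
            for (size_t p = 0; p < k; ++p)
                acc += A[i * k + p] * B[p * n + j];
            ref[i * n + j] = acc;
        }
    // Swapped operands, swapped m/n: the column-major C^T lands in
    // memory exactly where row-major C would.
    gemm_colmajor(n, m, k, B.data(), n, A.data(), k, C.data(), n);
    for (size_t i = 0; i < C.size(); ++i)
        assert(C[i] == ref[i]);
    return 0;
}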
@@ -13,147 +13,53 @@ | |||
#include "src/rocm/utils.h" | |||
#include "src/rocm/handle.h" | |||
#include "./algos.h" | |||
#include "src/common/algo_chooser.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
using namespace megdnn; | |||
using namespace rocm; | |||
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||
_megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace) | |||
{ | |||
check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
auto m = C.layout.shape[0], n = C.layout.shape[1]; | |||
auto k = A.layout.shape[param().transposeA ? 0 : 1]; | |||
auto handle = concrete_handle(this->handle()); | |||
auto rocblas_handle_ = handle->get_rocblas_handle(); | |||
auto sgemm = [&]() { | |||
auto zero = handle->zero_device(); | |||
auto one = handle->one_device(); | |||
rocblas_check(rocblas_sgemm( | |||
rocblas_handle_, | |||
param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, one, B.ptr<dt_float32>(), B.layout.stride[0], | |||
A.ptr<dt_float32>(), A.layout.stride[0], zero, | |||
C.ptr<dt_float32>(), C.layout.stride[0])); | |||
}; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
//! used for FLOAT_IO16xC32, not tested | |||
auto gemm_ex = [&]() { | |||
auto zero = handle->zero_device(); | |||
auto one = handle->one_device(); | |||
//! These two arguments for future use, see | |||
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||
int32_t solution_index = 0; | |||
uint32_t flags = 1; | |||
size_t ws_size = 0; | |||
auto gemm_ex_err = rocblas_gemm_ex( | |||
rocblas_handle_, | |||
param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r, | |||
B.layout.stride[0], A.raw_ptr, rocblas_datatype_f16_r, | |||
A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r, | |||
C.layout.stride[0], C.raw_ptr, rocblas_datatype_f16_r, | |||
C.layout.stride[0], rocblas_datatype_f32_r, | |||
rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||
nullptr); | |||
rocblas_check(gemm_ex_err); | |||
}; | |||
auto hgemm = [&]() { | |||
auto one_half = handle->one_device_h(); | |||
auto zero_half = handle->zero_device_h(); | |||
auto hgemm_err = rocblas_hgemm( | |||
rocblas_handle_, | |||
param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||
static_cast<const rocblas_half*>(B.raw_ptr), B.layout.stride[0], | |||
static_cast<const rocblas_half*>(A.raw_ptr), A.layout.stride[0], | |||
reinterpret_cast<const rocblas_half*>(zero_half), | |||
static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[0]); | |||
rocblas_check(hgemm_err); | |||
}; | |||
#endif | |||
std::vector<MatrixMulForwardImpl::Algorithm*> | |||
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) { | |||
AlgoBase::SizeArgs args{this, A, B, C}; | |||
return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | |||
} | |||
if (param().compute_mode == Param::ComputeMode::DEFAULT) { | |||
if (A.layout.dtype == dtype::Float32()) { | |||
sgemm(); | |||
} | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
else { | |||
megdnn_assert(A.layout.dtype == dtype::Float16(), | |||
"invalid matmul data type"); | |||
hgemm(); | |||
} | |||
#endif | |||
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, bool reproducible) { | |||
AlgoBase::SizeArgs args{this, A, B, C}; | |||
if (sm_algo_pack.blas.is_available_reproducible( | |||
args, reproducible, workspace_limit_in_bytes)) { | |||
return &sm_algo_pack.blas; | |||
} | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
else if (param().compute_mode == Param::ComputeMode::FLOAT32) { | |||
megdnn_assert(B.layout.dtype == dtype::Float16() && | |||
C.layout.dtype == dtype::Float16() && | |||
A.layout.dtype == dtype::Float16(), | |||
"DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||
"B, C are all Float16"); | |||
gemm_ex(); | |||
} | |||
#endif | |||
else if (A.layout.dtype == dtype::Int8() && | |||
B.layout.dtype == dtype::Int8() && | |||
C.layout.dtype == dtype::Int32()) { | |||
//! see | |||
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470 | |||
bool rocblas_int8x8x32_valid = true; | |||
rocblas_int8x8x32_valid &= (k % 4 == 0); | |||
rocblas_int8x8x32_valid &= | |||
(!param().transposeB || B.layout.stride[0] % 4 == 0); | |||
rocblas_int8x8x32_valid &= | |||
(!param().transposeA || A.layout.stride[0] % 4 == 0); | |||
megdnn_assert(rocblas_int8x8x32_valid, | |||
"rocblas int8x8x32 matmul requires K must be a multiple " | |||
"of 4, and/or LDA/LDB based on transpose mode" | |||
"get: %zu, is_trans_b = %d, %zu, is_trans_a = %d, %zu", | |||
k, param().transposeB, B.layout.stride[0], | |||
param().transposeA, A.layout.stride[0]); | |||
int32_t solution_index = 0; | |||
uint32_t flags = 1; | |||
size_t ws_size = 0; | |||
auto zero = handle->zero_device_i32(); | |||
auto one = handle->one_device_i32(); | |||
rocblas_check(rocblas_gemm_ex( | |||
rocblas_handle_, | |||
param().transposeB ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
param().transposeA ? rocblas_operation_transpose | |||
: rocblas_operation_none, | |||
n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r, | |||
B.layout.stride[0], A.raw_ptr, rocblas_datatype_i8_r, | |||
A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r, | |||
C.layout.stride[0], C.raw_ptr, rocblas_datatype_i32_r, | |||
C.layout.stride[0], rocblas_datatype_i32_r, | |||
rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||
nullptr)); | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( | |||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
"matrix mul forward"); | |||
} else { | |||
megdnn_assert((A.layout.dtype == dtype::Int8() && | |||
B.layout.dtype == dtype::Int8() && | |||
C.layout.dtype == dtype::Int16()), | |||
"invalid matmul data type"); | |||
megdnn_throw("cuda matmul does not support INT8x8x16 now"); | |||
return megdnn::get_usable_algo<MatrixMulForwardImpl>( | |||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
"matrix mul forward"); | |||
} | |||
} | |||
} // namespace rocm | |||
} // namespace megdnn | |||
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) { | |||
AlgoBase::SizeArgs args{this, A, B, C}; | |||
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||
} | |||
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace) { | |||
check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||
algo->check_workspace(args, workspace).exec(args); | |||
} | |||
// vim: syntax=cpp.doxygen |
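get_algorithm_heuristic() above short-circuits to the BLAS algorithm when it qualifies and otherwise defers to megdnn::get_reproducible_algo / megdnn::get_usable_algo from src/common/algo_chooser.h, which scan all_algos in registration order. Their core loop reduces to roughly the following sketch (the real helpers report an error with the "matrix mul forward" message where this sketch returns nullptr):

#include <cstddef>
#include <vector>

// Reduced sketch: return the first registered algorithm that fits the
// workspace limit and, when requested, is reproducible.
template <typename Algo, typename Args>
Algo* pick_algo(const std::vector<Algo*>& algos, const Args& args,
                size_t workspace_limit, bool require_reproducible) {
    for (auto* algo : algos) {
        if (require_reproducible && !algo->is_reproducible())
            continue;
        if (algo->is_available_wk(args, workspace_limit))
            return algo;
    }
    return nullptr;
}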
@@ -20,29 +20,32 @@ public: | |||
void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
const TensorLayout&) override; | |||
bool is_thread_safe() const override { return true; } | |||
class AlgoBase; | |||
class AlgoBlas; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
private: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/) override { | |||
return {}; | |||
} | |||
const TensorLayout& /*C*/) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/, | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /*reproducible*/) override { | |||
return nullptr; | |||
} | |||
bool /*reproducible*/) override; | |||
const char* get_algorithm_set_name() const override { | |||
return "ROCM MATMUL"; | |||
} | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
} // namespace rocm | |||
@@ -46,6 +46,37 @@ TEST_F(FALLBACK, MATRIX_MUL) { | |||
} | |||
} | |||
TEST_F(FALLBACK, MATRIX_MUL_NAIVE) { | |||
Checker<MatrixMul> checker(handle()); | |||
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE")); | |||
using Param = MatrixMul::Param; | |||
auto args = matrix_mul::get_matmul_args(); | |||
for (auto arg : args) { | |||
auto m = arg.m, n = arg.n, k = arg.k; | |||
auto mask = arg.mask; | |||
Param param; | |||
param.transposeA = mask & 1; | |||
param.transposeB = mask & 2; | |||
TensorShape AS, BS, CS; | |||
if (param.transposeA) | |||
AS = TensorShape{k, m}; | |||
else | |||
AS = TensorShape{m, k}; | |||
if (param.transposeB) | |||
BS = TensorShape{n, k}; | |||
else | |||
BS = TensorShape{k, n}; | |||
CS = TensorShape{m, n}; | |||
TensorLayout AL, BL, CL; | |||
AL = TensorLayout(AS, dtype::Float32()); | |||
BL = TensorLayout(BS, dtype::Float32()); | |||
CL = TensorLayout(CS, dtype::Float32()); | |||
checker.set_param(param); | |||
checker.execl({AL, BL, CL}); | |||
} | |||
} | |||
TEST_F(FALLBACK, BATCHED_MATRIX_MUL) { | |||
Checker<BatchedMatrixMul> checker(handle()); | |||
@@ -232,7 +232,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||
2, 5, 3, 3, 7, 4, -7, 1, | |||
-5, 7, -4, -1, -1, 2, 4, 1, | |||
7, 2, -6, -2, -6, 3, 4, 4, | |||
-2, 2, 3, 0, 6, 5, 3, 4, | |||
-2, 2, 3, 0, 6, 5, 3, 4, | |||
-1, -1, -5, 5, 2, 5, 1, 4, | |||
6, 2, 0, 0, 3, 2, 2, 1, | |||
-4, -3, 7, 5, 0, 3, 2, 3}), | |||
@@ -243,7 +243,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||
3, -1, 2, 2, 7, 3, 6, 0, | |||
5, 4, 0, 2, 2, 3, 3, 2, | |||
1, -8, -7, -6, 0, -5, -4, 4, | |||
-3, 7, 1, 6, -2, 2, -1, 5, | |||
-3, 7, 1, 6, -2, 2, -1, 5, | |||
2, 0, 7, 6, 5, 4, 3, 2, | |||
0, 0, 1, 0, 5, 2, 2, 6}), | |||
{}}, | |||