GitOrigin-RevId: 5c143ab3ac
tags/v1.3.0
@@ -54,12 +54,7 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&cublasLt); | |||
#endif | |||
all_algos.push_back(&int8x8x32); | |||
for (auto& algo : mm_pack.all_algos) { | |||
brute_force_algos.emplace_back(AlgoBruteForce(algo)); | |||
} | |||
for (auto& algo : brute_force_algos) { | |||
all_algos.push_back(&algo); | |||
} | |||
all_algos.push_back(&brute_force); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
@@ -87,26 +87,20 @@ public: | |||
class BatchedMatrixMulForwardImpl::AlgoBruteForce final | |||
: public BatchedMatrixMulForwardImpl::AlgoBase { | |||
using Param = MatrixMulForward::Param; | |||
private: | |||
std::string m_name; | |||
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | |||
WorkspaceBundle get_workspace_bundle(); | |||
public: | |||
AlgoBruteForce(MatrixMulForwardImpl::AlgoBase* algo); | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | |||
void exec(const ExecArgs& args) const final; | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return m_name.c_str(); } | |||
const char* name() const override { return "BRUTE_FORCE"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algorithm, ret); | |||
return ret; | |||
} | |||
std::vector<SearchItem> get_subopr_list( | |||
const TensorLayoutArray& layouts, | |||
const OperatorBase* opr) const override; | |||
}; | |||
class BatchedMatrixMulForwardImpl::AlgoCublas final | |||
: public BatchedMatrixMulForwardImpl::AlgoBase { | |||
@@ -157,7 +151,7 @@ public: | |||
#endif | |||
AlgoInt8x8x32 int8x8x32; | |||
std::vector<AlgoBase*> all_algos; | |||
std::vector<AlgoBruteForce> brute_force_algos; | |||
AlgoBruteForce brute_force; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
@@ -9,48 +9,86 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./algo.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/common/algo_chooser.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
BatchedMatrixMulForwardImpl::AlgoBruteForce::AlgoBruteForce( | |||
MatrixMulForwardImpl::AlgoBase* algo) | |||
: m_algorithm(algo) { | |||
m_name = ssprintf("BRUTE_FORCE-%s", algo->name()); | |||
namespace { | |||
std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config( | |||
const TensorLayout& layout_a, const TensorLayout& layout_b, | |||
const TensorLayout& layout_c, const BatchedMatrixMulForward* opr) { | |||
auto mm_layout_a = layout_a.remove_axis(0); | |||
auto mm_layout_b = layout_b.remove_axis(0); | |||
auto mm_layout_c = layout_c.remove_axis(0); | |||
return {{mm_layout_a, mm_layout_b, mm_layout_c}, opr->param()}; | |||
} | |||
} // namespace | |||
std::vector<Algorithm::SearchItem> | |||
BatchedMatrixMulForwardImpl::AlgoBruteForce::get_subopr_list( | |||
const TensorLayoutArray& layouts, const OperatorBase* opr) const { | |||
const BatchedMatrixMulForwardImpl* bmm_opr = | |||
static_cast<const BatchedMatrixMulForwardImpl*>(opr); | |||
auto&& config = sub_opr_config(layouts[0], layouts[1], layouts[2], bmm_opr); | |||
std::string param_str; | |||
Algorithm::serialize_write_pod(config.second, param_str); | |||
return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}}; | |||
} | |||
bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( | |||
const SizeArgs& args) const { | |||
MatrixMulForwardImpl mm{args.opr->handle()}; | |||
mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; | |||
mm.execution_policy() = {m_algorithm->desc(), {}}; | |||
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
if (args.opr->execution_policy().algo.valid() && | |||
!args.opr->execution_policy().sub_policy.empty()) { | |||
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); | |||
matmul_opr->execution_policy() = | |||
args.opr->execution_policy().sub_policy[0]; | |||
} | |||
auto mm_layout_a = args.layout_a.remove_axis(0); | |||
auto mm_layout_b = args.layout_b.remove_axis(0); | |||
auto mm_layout_c = args.layout_c.remove_axis(0); | |||
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, | |||
args.opr); | |||
matmul_opr->param() = config.second; | |||
MatrixMulForwardImpl::AlgoBase::SizeArgs mm_args{&mm, mm_layout_a, | |||
mm_layout_b, mm_layout_c}; | |||
return m_algorithm->is_available(mm_args); | |||
return get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()), | |||
config.first[0], config.first[1], config.first[2]); | |||
} | |||
size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( | |||
const SizeArgs& args) const { | |||
auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
mm_opr->param() = {args.opr->param().transposeA, | |||
args.opr->param().transposeB}; | |||
mm_opr->execution_policy() = {m_algorithm->desc(), {}}; | |||
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
if (args.opr->execution_policy().algo.valid() && | |||
!args.opr->execution_policy().sub_policy.empty()) { | |||
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); | |||
matmul_opr->execution_policy() = | |||
args.opr->execution_policy().sub_policy[0]; | |||
} | |||
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, | |||
args.opr); | |||
matmul_opr->param() = config.second; | |||
return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, | |||
args.layout_c); | |||
return matmul_opr->get_workspace_in_bytes(config.first[0], config.first[1], | |||
config.first[2]); | |||
} | |||
void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | |||
const ExecArgs& args) const { | |||
auto N = args.layout_a.shape[0]; | |||
auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
mm_opr->param() = {args.opr->param().transposeA, | |||
args.opr->param().transposeB}; | |||
mm_opr->execution_policy() = {m_algorithm->desc(), {}}; | |||
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
if (args.opr->execution_policy().algo.valid()) { | |||
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); | |||
matmul_opr->execution_policy() = | |||
args.opr->execution_policy().sub_policy[0]; | |||
} | |||
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, | |||
args.opr); | |||
matmul_opr->param() = config.second; | |||
rep(n, N) { | |||
TensorND A_, B_, C_; | |||
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | |||
@@ -62,6 +100,6 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | |||
tensor_n_from_batch(args.tensor_a, A_); | |||
tensor_n_from_batch(args.tensor_b, B_); | |||
tensor_n_from_batch(args.tensor_c, C_); | |||
mm_opr->exec(A_, B_, C_, args.workspace); | |||
matmul_opr->exec(A_, B_, C_, args.workspace); | |||
} | |||
} |
@@ -56,9 +56,8 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( | |||
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, bool reproducible) { | |||
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||
AlgoBase::SizeArgs args(this, A, B, C); | |||
std::vector<AlgoBase*> brute_force_algos; | |||
if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) { | |||
return &sm_algo_pack.cublas; | |||
} | |||
@@ -72,25 +71,14 @@ Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
reproducible)) { | |||
return &sm_algo_pack.int8x8x32; | |||
} else { | |||
for (auto& algo : sm_algo_pack.brute_force_algos) { | |||
if (algo.is_available_reproducible(args, reproducible)) { | |||
return &algo; | |||
} | |||
if (sm_algo_pack.brute_force.is_available_reproducible(args, | |||
reproducible)) { | |||
return &sm_algo_pack.brute_force; | |||
} | |||
} | |||
for (auto& algo : sm_algo_pack.brute_force_algos) | |||
brute_force_algos.push_back(&algo); | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||
brute_force_algos, args, workspace_limit_in_bytes, | |||
"batched matrix mul"); | |||
} else { | |||
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||
brute_force_algos, args, workspace_limit_in_bytes, | |||
"batched matrix mul"); | |||
} | |||
megdnn_throw("No usable algo for batched_matrix_mul"); | |||
return nullptr; | |||
}; | |||
// vim: syntax=cpp.doxygen |
@@ -138,12 +138,13 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args() { | |||
template <typename Opr> | |||
void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||
Handle* handle, const char* algo, | |||
Handle* handle, | |||
const ExecutionPolicyAlgoName& algo, | |||
param::MatrixMul::Format format, size_t nbase, | |||
float eps, std::vector<TestArg>&& user_args) { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | |||
Checker<Opr> checker(handle); | |||
if (algo) { | |||
if (!algo.name.empty()) { | |||
checker.set_before_exec_callback(AlgoChecker<Opr>(algo)); | |||
} | |||
std::unique_ptr<RNG> rng; | |||
@@ -267,7 +268,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||
void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, | |||
DType C_dtype, Handle* handle, | |||
const char* algo, float eps, | |||
const ExecutionPolicyAlgoName& algo, | |||
float eps, | |||
std::vector<TestArg>&& args) { | |||
check_matrix_mul<megdnn::BatchedMatrixMul>( | |||
A_dtype, B_dtype, C_dtype, handle, algo, | |||
@@ -276,7 +278,8 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, | |||
} | |||
void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||
Handle* handle, const char* algo, | |||
Handle* handle, | |||
const ExecutionPolicyAlgoName& algo, | |||
param::MatrixMul::Format format, size_t nbase, | |||
float eps) { | |||
check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo, | |||
@@ -16,6 +16,7 @@ | |||
#include "megdnn/handle.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
namespace megdnn { | |||
namespace test { | |||
@@ -58,18 +59,19 @@ using TestArgFilterFunc = std::function<bool(const TestArg&)>; | |||
template <typename Opr = megdnn::MatrixMul> | |||
void check_matrix_mul( | |||
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | |||
const char* algo = nullptr, | |||
const ExecutionPolicyAlgoName& algo = {"", {}}, | |||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | |||
size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}); | |||
void check_matrix_mul( | |||
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | |||
const char* algo = nullptr, | |||
const ExecutionPolicyAlgoName& algo = {"", {}}, | |||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | |||
size_t nbase = 8, float eps = 1e-3); | |||
void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||
Handle* handle, const char* algo = nullptr, | |||
Handle* handle, | |||
const ExecutionPolicyAlgoName& algo = {"", {}}, | |||
float eps = 1e-3, | |||
std::vector<TestArg>&& args = {}); | |||
@@ -20,8 +20,8 @@ using namespace test; | |||
//! check batch=1 and batch_stride is arbitrarily | |||
TEST_F(CPU, BATCHED_MATRIX_MUL_BATCH_1) { | |||
matrix_mul::check_batched_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
nullptr, 1e-3, | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "", | |||
1e-3, | |||
std::vector<matrix_mul::TestArg>{ | |||
{5, 5, 5, 0, 5, 5, 5, 1, 5, 5, 5}}); | |||
} | |||
@@ -62,6 +62,34 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) { | |||
#undef F32_TEST_PART | |||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) { | |||
matrix_mul::check_batched_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||
matrix_mul::get_batched_matmul_args_mask(0)); | |||
} | |||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) { | |||
matrix_mul::check_batched_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||
matrix_mul::get_batched_matmul_args_mask(1)); | |||
} | |||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) { | |||
matrix_mul::check_batched_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||
matrix_mul::get_batched_matmul_args_mask(2)); | |||
} | |||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) { | |||
matrix_mul::check_batched_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||
matrix_mul::get_batched_matmul_args_mask(3)); | |||
} | |||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) { | |||
require_compute_capability(6, 0); | |||
matrix_mul::check_batched_matrix_mul( | |||
@@ -150,7 +178,8 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) { | |||
TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { | |||
require_compute_capability(6, 1); | |||
auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k, | |||
const char* algo1, const char* algo2, size_t b = 128) { | |||
const ExecutionPolicyAlgoName& algo1, | |||
const ExecutionPolicyAlgoName& algo2, size_t b = 128) { | |||
size_t RUNS = 10; | |||
CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda()); | |||
bencher1.set_display(false).set_times(RUNS); | |||
@@ -196,19 +225,20 @@ TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { | |||
printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s " | |||
"/ " | |||
"%s %.3f\n", | |||
transA, transB, m, n, k, b, algo1, algo2, flops1 / flops2); | |||
transA, transB, m, n, k, b, algo1.name.c_str(), | |||
algo2.name.c_str(), flops1 / flops2); | |||
}; | |||
for (bool transA : {0, 1}) | |||
for (bool transB : {0, 1}) { | |||
run(transA, transB, 128, 576, 128, "INT8x8x32", | |||
"BRUTE_FORCE-CUBLAS"); | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||
run(transA, transB, 256, 144, 256, "INT8x8x32", | |||
"BRUTE_FORCE-CUBLAS"); | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||
run(transA, transB, 512, 36, 512, "INT8x8x32", | |||
"BRUTE_FORCE-CUBLAS"); | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||
run(transA, transB, 1024, 8, 1024, "INT8x8x32", | |||
"BRUTE_FORCE-CUBLAS"); | |||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||
} | |||
} | |||
#endif | |||