GitOrigin-RevId: 5c143ab3ac
tags/v1.3.0
@@ -54,12 +54,7 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
all_algos.push_back(&cublasLt); | all_algos.push_back(&cublasLt); | ||||
#endif | #endif | ||||
all_algos.push_back(&int8x8x32); | all_algos.push_back(&int8x8x32); | ||||
for (auto& algo : mm_pack.all_algos) { | |||||
brute_force_algos.emplace_back(AlgoBruteForce(algo)); | |||||
} | |||||
for (auto& algo : brute_force_algos) { | |||||
all_algos.push_back(&algo); | |||||
} | |||||
all_algos.push_back(&brute_force); | |||||
for (auto&& algo : all_algos) { | for (auto&& algo : all_algos) { | ||||
m_all_algos_map.emplace(algo->info().desc, algo); | m_all_algos_map.emplace(algo->info().desc, algo); | ||||
@@ -87,26 +87,20 @@ public: | |||||
class BatchedMatrixMulForwardImpl::AlgoBruteForce final | class BatchedMatrixMulForwardImpl::AlgoBruteForce final | ||||
: public BatchedMatrixMulForwardImpl::AlgoBase { | : public BatchedMatrixMulForwardImpl::AlgoBase { | ||||
using Param = MatrixMulForward::Param; | using Param = MatrixMulForward::Param; | ||||
private: | private: | ||||
std::string m_name; | |||||
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | |||||
WorkspaceBundle get_workspace_bundle(); | WorkspaceBundle get_workspace_bundle(); | ||||
public: | public: | ||||
AlgoBruteForce(MatrixMulForwardImpl::AlgoBase* algo); | |||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return m_name.c_str(); } | |||||
const char* name() const override { return "BRUTE_FORCE"; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_algorithm, ret); | |||||
return ret; | |||||
} | |||||
std::vector<SearchItem> get_subopr_list( | |||||
const TensorLayoutArray& layouts, | |||||
const OperatorBase* opr) const override; | |||||
}; | }; | ||||
class BatchedMatrixMulForwardImpl::AlgoCublas final | class BatchedMatrixMulForwardImpl::AlgoCublas final | ||||
: public BatchedMatrixMulForwardImpl::AlgoBase { | : public BatchedMatrixMulForwardImpl::AlgoBase { | ||||
@@ -157,7 +151,7 @@ public: | |||||
#endif | #endif | ||||
AlgoInt8x8x32 int8x8x32; | AlgoInt8x8x32 int8x8x32; | ||||
std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
std::vector<AlgoBruteForce> brute_force_algos; | |||||
AlgoBruteForce brute_force; | |||||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | ||||
}; | }; | ||||
@@ -9,48 +9,86 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "./algo.h" | #include "./algo.h" | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/common/algo_chooser.h" | |||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
BatchedMatrixMulForwardImpl::AlgoBruteForce::AlgoBruteForce( | |||||
MatrixMulForwardImpl::AlgoBase* algo) | |||||
: m_algorithm(algo) { | |||||
m_name = ssprintf("BRUTE_FORCE-%s", algo->name()); | |||||
namespace { | |||||
std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config( | |||||
const TensorLayout& layout_a, const TensorLayout& layout_b, | |||||
const TensorLayout& layout_c, const BatchedMatrixMulForward* opr) { | |||||
auto mm_layout_a = layout_a.remove_axis(0); | |||||
auto mm_layout_b = layout_b.remove_axis(0); | |||||
auto mm_layout_c = layout_c.remove_axis(0); | |||||
return {{mm_layout_a, mm_layout_b, mm_layout_c}, opr->param()}; | |||||
} | |||||
} // namespace | |||||
std::vector<Algorithm::SearchItem> | |||||
BatchedMatrixMulForwardImpl::AlgoBruteForce::get_subopr_list( | |||||
const TensorLayoutArray& layouts, const OperatorBase* opr) const { | |||||
const BatchedMatrixMulForwardImpl* bmm_opr = | |||||
static_cast<const BatchedMatrixMulForwardImpl*>(opr); | |||||
auto&& config = sub_opr_config(layouts[0], layouts[1], layouts[2], bmm_opr); | |||||
std::string param_str; | |||||
Algorithm::serialize_write_pod(config.second, param_str); | |||||
return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}}; | |||||
} | } | ||||
bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( | bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
MatrixMulForwardImpl mm{args.opr->handle()}; | |||||
mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; | |||||
mm.execution_policy() = {m_algorithm->desc(), {}}; | |||||
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
if (args.opr->execution_policy().algo.valid() && | |||||
!args.opr->execution_policy().sub_policy.empty()) { | |||||
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); | |||||
matmul_opr->execution_policy() = | |||||
args.opr->execution_policy().sub_policy[0]; | |||||
} | |||||
auto mm_layout_a = args.layout_a.remove_axis(0); | |||||
auto mm_layout_b = args.layout_b.remove_axis(0); | |||||
auto mm_layout_c = args.layout_c.remove_axis(0); | |||||
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, | |||||
args.opr); | |||||
matmul_opr->param() = config.second; | |||||
MatrixMulForwardImpl::AlgoBase::SizeArgs mm_args{&mm, mm_layout_a, | |||||
mm_layout_b, mm_layout_c}; | |||||
return m_algorithm->is_available(mm_args); | |||||
return get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()), | |||||
config.first[0], config.first[1], config.first[2]); | |||||
} | } | ||||
size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( | size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
mm_opr->param() = {args.opr->param().transposeA, | |||||
args.opr->param().transposeB}; | |||||
mm_opr->execution_policy() = {m_algorithm->desc(), {}}; | |||||
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
if (args.opr->execution_policy().algo.valid() && | |||||
!args.opr->execution_policy().sub_policy.empty()) { | |||||
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); | |||||
matmul_opr->execution_policy() = | |||||
args.opr->execution_policy().sub_policy[0]; | |||||
} | |||||
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, | |||||
args.opr); | |||||
matmul_opr->param() = config.second; | |||||
return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, | |||||
args.layout_c); | |||||
return matmul_opr->get_workspace_in_bytes(config.first[0], config.first[1], | |||||
config.first[2]); | |||||
} | } | ||||
void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
auto N = args.layout_a.shape[0]; | auto N = args.layout_a.shape[0]; | ||||
auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
mm_opr->param() = {args.opr->param().transposeA, | |||||
args.opr->param().transposeB}; | |||||
mm_opr->execution_policy() = {m_algorithm->desc(), {}}; | |||||
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
if (args.opr->execution_policy().algo.valid()) { | |||||
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); | |||||
matmul_opr->execution_policy() = | |||||
args.opr->execution_policy().sub_policy[0]; | |||||
} | |||||
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, | |||||
args.opr); | |||||
matmul_opr->param() = config.second; | |||||
rep(n, N) { | rep(n, N) { | ||||
TensorND A_, B_, C_; | TensorND A_, B_, C_; | ||||
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | ||||
@@ -62,6 +100,6 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | |||||
tensor_n_from_batch(args.tensor_a, A_); | tensor_n_from_batch(args.tensor_a, A_); | ||||
tensor_n_from_batch(args.tensor_b, B_); | tensor_n_from_batch(args.tensor_b, B_); | ||||
tensor_n_from_batch(args.tensor_c, C_); | tensor_n_from_batch(args.tensor_c, C_); | ||||
mm_opr->exec(A_, B_, C_, args.workspace); | |||||
matmul_opr->exec(A_, B_, C_, args.workspace); | |||||
} | } | ||||
} | } |
@@ -56,9 +56,8 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( | |||||
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | size_t workspace_limit_in_bytes, bool reproducible) { | ||||
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||||
AlgoBase::SizeArgs args(this, A, B, C); | AlgoBase::SizeArgs args(this, A, B, C); | ||||
std::vector<AlgoBase*> brute_force_algos; | |||||
if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) { | if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) { | ||||
return &sm_algo_pack.cublas; | return &sm_algo_pack.cublas; | ||||
} | } | ||||
@@ -72,25 +71,14 @@ Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||||
reproducible)) { | reproducible)) { | ||||
return &sm_algo_pack.int8x8x32; | return &sm_algo_pack.int8x8x32; | ||||
} else { | } else { | ||||
for (auto& algo : sm_algo_pack.brute_force_algos) { | |||||
if (algo.is_available_reproducible(args, reproducible)) { | |||||
return &algo; | |||||
} | |||||
if (sm_algo_pack.brute_force.is_available_reproducible(args, | |||||
reproducible)) { | |||||
return &sm_algo_pack.brute_force; | |||||
} | } | ||||
} | } | ||||
for (auto& algo : sm_algo_pack.brute_force_algos) | |||||
brute_force_algos.push_back(&algo); | |||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||||
brute_force_algos, args, workspace_limit_in_bytes, | |||||
"batched matrix mul"); | |||||
} else { | |||||
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||||
brute_force_algos, args, workspace_limit_in_bytes, | |||||
"batched matrix mul"); | |||||
} | |||||
megdnn_throw("No usable algo for batched_matrix_mul"); | |||||
return nullptr; | |||||
}; | }; | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -138,12 +138,13 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args() { | |||||
template <typename Opr> | template <typename Opr> | ||||
void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | ||||
Handle* handle, const char* algo, | |||||
Handle* handle, | |||||
const ExecutionPolicyAlgoName& algo, | |||||
param::MatrixMul::Format format, size_t nbase, | param::MatrixMul::Format format, size_t nbase, | ||||
float eps, std::vector<TestArg>&& user_args) { | float eps, std::vector<TestArg>&& user_args) { | ||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | ||||
Checker<Opr> checker(handle); | Checker<Opr> checker(handle); | ||||
if (algo) { | |||||
if (!algo.name.empty()) { | |||||
checker.set_before_exec_callback(AlgoChecker<Opr>(algo)); | checker.set_before_exec_callback(AlgoChecker<Opr>(algo)); | ||||
} | } | ||||
std::unique_ptr<RNG> rng; | std::unique_ptr<RNG> rng; | ||||
@@ -267,7 +268,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||||
void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, | void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, | ||||
DType C_dtype, Handle* handle, | DType C_dtype, Handle* handle, | ||||
const char* algo, float eps, | |||||
const ExecutionPolicyAlgoName& algo, | |||||
float eps, | |||||
std::vector<TestArg>&& args) { | std::vector<TestArg>&& args) { | ||||
check_matrix_mul<megdnn::BatchedMatrixMul>( | check_matrix_mul<megdnn::BatchedMatrixMul>( | ||||
A_dtype, B_dtype, C_dtype, handle, algo, | A_dtype, B_dtype, C_dtype, handle, algo, | ||||
@@ -276,7 +278,8 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, | |||||
} | } | ||||
void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | ||||
Handle* handle, const char* algo, | |||||
Handle* handle, | |||||
const ExecutionPolicyAlgoName& algo, | |||||
param::MatrixMul::Format format, size_t nbase, | param::MatrixMul::Format format, size_t nbase, | ||||
float eps) { | float eps) { | ||||
check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo, | check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo, | ||||
@@ -16,6 +16,7 @@ | |||||
#include "megdnn/handle.h" | #include "megdnn/handle.h" | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "test/common/checker.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
@@ -58,18 +59,19 @@ using TestArgFilterFunc = std::function<bool(const TestArg&)>; | |||||
template <typename Opr = megdnn::MatrixMul> | template <typename Opr = megdnn::MatrixMul> | ||||
void check_matrix_mul( | void check_matrix_mul( | ||||
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | ||||
const char* algo = nullptr, | |||||
const ExecutionPolicyAlgoName& algo = {"", {}}, | |||||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | ||||
size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}); | size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}); | ||||
void check_matrix_mul( | void check_matrix_mul( | ||||
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | ||||
const char* algo = nullptr, | |||||
const ExecutionPolicyAlgoName& algo = {"", {}}, | |||||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | ||||
size_t nbase = 8, float eps = 1e-3); | size_t nbase = 8, float eps = 1e-3); | ||||
void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | ||||
Handle* handle, const char* algo = nullptr, | |||||
Handle* handle, | |||||
const ExecutionPolicyAlgoName& algo = {"", {}}, | |||||
float eps = 1e-3, | float eps = 1e-3, | ||||
std::vector<TestArg>&& args = {}); | std::vector<TestArg>&& args = {}); | ||||
@@ -20,8 +20,8 @@ using namespace test; | |||||
//! check batch=1 and batch_stride is arbitrarily | //! check batch=1 and batch_stride is arbitrarily | ||||
TEST_F(CPU, BATCHED_MATRIX_MUL_BATCH_1) { | TEST_F(CPU, BATCHED_MATRIX_MUL_BATCH_1) { | ||||
matrix_mul::check_batched_matrix_mul( | matrix_mul::check_batched_matrix_mul( | ||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||||
nullptr, 1e-3, | |||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "", | |||||
1e-3, | |||||
std::vector<matrix_mul::TestArg>{ | std::vector<matrix_mul::TestArg>{ | ||||
{5, 5, 5, 0, 5, 5, 5, 1, 5, 5, 5}}); | {5, 5, 5, 0, 5, 5, 5, 1, 5, 5, 5}}); | ||||
} | } | ||||
@@ -62,6 +62,34 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) { | |||||
#undef F32_TEST_PART | #undef F32_TEST_PART | ||||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) { | |||||
matrix_mul::check_batched_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||||
matrix_mul::get_batched_matmul_args_mask(0)); | |||||
} | |||||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) { | |||||
matrix_mul::check_batched_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||||
matrix_mul::get_batched_matmul_args_mask(1)); | |||||
} | |||||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) { | |||||
matrix_mul::check_batched_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||||
matrix_mul::get_batched_matmul_args_mask(2)); | |||||
} | |||||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) { | |||||
matrix_mul::check_batched_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, | |||||
matrix_mul::get_batched_matmul_args_mask(3)); | |||||
} | |||||
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) { | TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) { | ||||
require_compute_capability(6, 0); | require_compute_capability(6, 0); | ||||
matrix_mul::check_batched_matrix_mul( | matrix_mul::check_batched_matrix_mul( | ||||
@@ -150,7 +178,8 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) { | |||||
TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { | TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { | ||||
require_compute_capability(6, 1); | require_compute_capability(6, 1); | ||||
auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k, | auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k, | ||||
const char* algo1, const char* algo2, size_t b = 128) { | |||||
const ExecutionPolicyAlgoName& algo1, | |||||
const ExecutionPolicyAlgoName& algo2, size_t b = 128) { | |||||
size_t RUNS = 10; | size_t RUNS = 10; | ||||
CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda()); | CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda()); | ||||
bencher1.set_display(false).set_times(RUNS); | bencher1.set_display(false).set_times(RUNS); | ||||
@@ -196,19 +225,20 @@ TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { | |||||
printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s " | printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s " | ||||
"/ " | "/ " | ||||
"%s %.3f\n", | "%s %.3f\n", | ||||
transA, transB, m, n, k, b, algo1, algo2, flops1 / flops2); | |||||
transA, transB, m, n, k, b, algo1.name.c_str(), | |||||
algo2.name.c_str(), flops1 / flops2); | |||||
}; | }; | ||||
for (bool transA : {0, 1}) | for (bool transA : {0, 1}) | ||||
for (bool transB : {0, 1}) { | for (bool transB : {0, 1}) { | ||||
run(transA, transB, 128, 576, 128, "INT8x8x32", | run(transA, transB, 128, 576, 128, "INT8x8x32", | ||||
"BRUTE_FORCE-CUBLAS"); | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||||
run(transA, transB, 256, 144, 256, "INT8x8x32", | run(transA, transB, 256, 144, 256, "INT8x8x32", | ||||
"BRUTE_FORCE-CUBLAS"); | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||||
run(transA, transB, 512, 36, 512, "INT8x8x32", | run(transA, transB, 512, 36, 512, "INT8x8x32", | ||||
"BRUTE_FORCE-CUBLAS"); | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||||
run(transA, transB, 1024, 8, 1024, "INT8x8x32", | run(transA, transB, 1024, 8, 1024, "INT8x8x32", | ||||
"BRUTE_FORCE-CUBLAS"); | |||||
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); | |||||
} | } | ||||
} | } | ||||
#endif | #endif | ||||