GitOrigin-RevId: 585b2c045a
release-1.6
@@ -6,14 +6,15 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./algos.h" | |||
#include "src/cuda/utils.h" | |||
#include "src/common/algo_base.h" | |||
#include <cuda.h> | |||
#include "src/common/algo_base.h" | |||
#include "src/cuda/conv_bias/algo.h" | |||
#include "src/cuda/conv_bias/opr_impl.h" | |||
#include "src/cuda/utils.h" | |||
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

@@ -52,8 +53,21 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
    }
#endif
#endif
    all_algos.push_back(&naive);
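    // Register one MATMUL_CONV1X1 variant for every cuDNN convolution
    // forward algorithm exposed by CudnnAlgoPack.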
    std::vector<cudnnConvolutionFwdAlgo_t> cudnn_conv_enum;
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_enum.push_back(algo.first);
    }
    for (auto&& algo : cudnn_conv_enum) {
        conv1x1.push_back(AlgoConv1X1CUDNN(algo));
    }
    for (size_t i = 0; i < conv1x1.size(); ++i) {
        all_algos.push_back(&conv1x1[i]);
    }

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
@@ -11,19 +11,19 @@
 */
#pragma once
#include <cuda.h>
#include <memory>
#include <unordered_map>
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/matrix_mul/opr_impl.h"
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

namespace megdnn {
namespace cuda {

@@ -42,6 +42,7 @@ public:
    CUDA_CUBLASLT,
    CUDA_NAIVE,
    CUDA_BFLOAT16,
    CUDA_CONV1X1_CUDNN,
#if CUDA_VERSION >= 9020
    CUDA_FLOAT32_SIMT,
    CUDA_FLOAT32_SIMT_SPLIT_K,

@@ -189,6 +190,38 @@ private:
};
#endif
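// Adapter algorithm: runs matrix multiplication as a 1x1 convolution by
// delegating to a wrapped ConvBiasForwardImpl::AlgoCUDNNConv instance.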
class MatrixMulForwardImpl::AlgoConv1X1CUDNN final : public AlgoBase {
public:
    AlgoConv1X1CUDNN(cudnnConvolutionFwdAlgo_t algo_enum) {
        m_impl = std::make_unique<ConvBiasForwardImpl::AlgoCUDNNConv>(
                algo_enum);
        std::string algoname(m_impl->name());
        m_name = "MATMUL_CONV1X1:" + algoname;
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    void exec(const ExecArgs& args) const override;
    AlgoAttribute attribute() const override {
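        // Collect every inheritable attribute (plus REPRODUCIBLE) that the
        // wrapped conv algorithm reports.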
        auto ret = AlgoAttribute::DEFAULT;
#define cb(attr)                                \
    if (m_impl->contain_attribute_all(attr)) {  \
        ret |= attr;                            \
    }
        MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
#undef cb
        if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        return ret;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CONV1X1_CUDNN)

private:
    std::unique_ptr<ConvBiasForwardImpl::AlgoCUDNNConv> m_impl;
    std::string m_name;
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
#if CUDA_VERSION >= 9020
class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
public:

@@ -244,7 +277,8 @@ public:
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
    std::string param() const override {
        std::string ret;
        // FIXME: algo param compatible with old version, to avoid fastrun cache
        // error
        struct AlgoParam_ {
            int threadblock_m, threadblock_n, threadblock_k;
            int warp_m, warp_n, warp_k;

@@ -272,7 +306,7 @@ public:
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {

@@ -409,6 +443,7 @@ public:
    std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
#endif
#endif
    std::vector<AlgoConv1X1CUDNN> conv1x1;
    std::vector<AlgoBase*> all_algos;
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
@@ -0,0 +1,173 @@
#include <cuda.h>
#include "./algos.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/relayout/opr_impl.h"
#include "src/cuda/transpose/opr_impl.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
namespace {
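// Build a ConvBias operator configured so that C = A * B becomes a 1x1,
// stride-1, unpadded NCHW convolution: B is the (1, k, 1, n) source, A the
// (m, k, 1, 1) filter, and C the (1, m, 1, n) destination.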
std::unique_ptr<ConvBiasForward> prepare_conv_opr(
        const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
    auto conv_bias_opr_ptr =
            args.opr->handle()->create_operator<ConvBiasForward>();
    auto conv_param_computemode =
            (args.opr->param().compute_mode ==
             param::MatrixMul::ComputeMode::DEFAULT)
                    ? param::Convolution::ComputeMode::DEFAULT
                    : param::Convolution::ComputeMode::FLOAT32;
    conv_bias_opr_ptr->param() = {param::ConvBias::NonlineMode::IDENTITY,
                                  param::Convolution::Mode::CROSS_CORRELATION,
                                  param::Convolution::Sparse::DENSE,
                                  param::Convolution::Format::NCHW,
                                  0,  // pad_h
                                  0,  // pad_w
                                  1,  // stride_h
                                  1,  // stride_w
                                  1,  // dilate_h
                                  1,  // dilate_w
                                  conv_param_computemode};
    return conv_bias_opr_ptr;
}
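// Derive (m, k, n) from the operand layouts, honoring the transposeA and
// transposeB flags; asserts that the shared dimension k agrees.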
std::tuple<size_t, size_t, size_t> gen_matrixmul_shape(
        const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
    size_t m, k, n;
    if (!args.opr->param().transposeA) {
        m = args.layout_a.shape[0];
        k = args.layout_a.shape[1];
    } else {
        m = args.layout_a.shape[1];
        k = args.layout_a.shape[0];
    }
    if (!args.opr->param().transposeB) {
        megdnn_assert(k == args.layout_b.shape[0]);
        n = args.layout_b.shape[1];
    } else {
        megdnn_assert(k == args.layout_b.shape[1]);
        n = args.layout_b.shape[0];
    }
    return std::make_tuple(m, k, n);
}
} // namespace
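// Only plain 2-D operands are supported; availability is then delegated to
// the wrapped cuDNN conv algorithm on the equivalent convolution layouts.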
bool MatrixMulForwardImpl::AlgoConv1X1CUDNN::is_available(
        const SizeArgs& args) const {
    if (!(args.layout_a.ndim == 2 && args.layout_b.ndim == 2 &&
          args.layout_c.ndim == 2))
        return false;
    auto conv_opr_ptr = prepare_conv_opr(args);
    size_t m, k, n;
    std::tie(m, k, n) = gen_matrixmul_shape(args);
    TensorLayout src_layout({1, k, 1, n}, args.layout_b.dtype);
    TensorLayout filter_layout({m, k, 1, 1}, args.layout_a.dtype);
    TensorLayout bias_layout(args.layout_a.dtype);
    TensorLayout z_layout(args.layout_a.dtype);
    TensorLayout dst_layout({1, m, 1, n}, args.layout_c.dtype);
    ConvBiasForwardImpl::AlgoBase::SizeArgs conv_size_args(
            static_cast<ConvBiasForwardImpl*>(conv_opr_ptr.get()), src_layout,
            filter_layout, bias_layout, z_layout, dst_layout);
    return m_impl->is_available(conv_size_args);
}
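// Workspace layout: slot 0 is the conv workspace; an extra slot is appended
// for a transposed copy of A and/or B when the matching transpose flag is set.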
WorkspaceBundle MatrixMulForwardImpl::AlgoConv1X1CUDNN::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    SmallVector<size_t> sizes;
    auto conv_opr_ptr = prepare_conv_opr(args);
    size_t m, k, n;
    std::tie(m, k, n) = gen_matrixmul_shape(args);
    TensorLayout src_layout({1, k, 1, n}, args.layout_b.dtype);
    TensorLayout filter_layout({m, k, 1, 1}, args.layout_a.dtype);
    TensorLayout bias_layout(args.layout_a.dtype);
    TensorLayout z_layout(args.layout_a.dtype);
    TensorLayout dst_layout({1, m, 1, n}, args.layout_c.dtype);
    ConvBiasForwardImpl::AlgoBase::SizeArgs conv_size_args(
            static_cast<ConvBiasForwardImpl*>(conv_opr_ptr.get()), src_layout,
            filter_layout, bias_layout, z_layout, dst_layout);
    sizes.push_back(m_impl->get_workspace_in_bytes(conv_size_args));

    auto get_trans_layout = [](const TensorLayout& ly) {
        size_t m = ly[0], n = ly[1];
        TensorLayout trans{{n, m}, ly.dtype};
        return trans;
    };
    if (args.opr->param().transposeA) {
        sizes.push_back(get_trans_layout(args.layout_a).span().dist_byte());
    }
    if (args.opr->param().transposeB) {
        sizes.push_back(get_trans_layout(args.layout_b).span().dist_byte());
    }
    return {ptr, std::move(sizes)};
}

size_t MatrixMulForwardImpl::AlgoConv1X1CUDNN::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
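// Materialize transposed copies of A/B in the workspace if requested, then
// run the wrapped cuDNN conv algorithm on the 1x1-convolution view.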
void MatrixMulForwardImpl::AlgoConv1X1CUDNN::exec(const ExecArgs& args) const {
    SizeArgs size_args(args.opr, args.layout_a, args.layout_b, args.layout_c);
    auto conv_opr_ptr = prepare_conv_opr(size_args);
    size_t m, k, n;
    std::tie(m, k, n) = gen_matrixmul_shape(size_args);
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, size_args);
    auto A_dst_tensor = args.tensor_a;
    auto B_dst_tensor = args.tensor_b;
    if (args.opr->param().transposeA || args.opr->param().transposeB) {
        auto trans = args.opr->handle()->create_operator<RelayoutForward>();
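        // Copy a transposed view of ori_tensor into the given workspace
        // slot: the source aliases ori_tensor's storage with swapped
        // strides, and RelayoutForward writes it out contiguously.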
        auto trans_tensor = [&](size_t workspace_pos,
                                const TensorND& ori_tensor,
                                TensorND& dst_tensor) {
            TensorLayout dst_layout(
                    {ori_tensor.layout.shape[1], ori_tensor.layout.shape[0]},
                    ori_tensor.layout.dtype);
            dst_tensor = TensorND(bundle.get(workspace_pos), dst_layout);

            TensorND src_tensor(ori_tensor.raw_ptr, dst_layout);
            src_tensor.layout.stride[0] = ori_tensor.layout.stride[1];
            src_tensor.layout.stride[1] = ori_tensor.layout.stride[0];

            trans->exec(src_tensor, dst_tensor, args.opr->handle());
        };
        if (args.opr->param().transposeA) {
            trans_tensor(1, args.tensor_a, A_dst_tensor);
        }
        if (args.opr->param().transposeB) {
            trans_tensor(bundle.nr_workspace() - 1, args.tensor_b,
                         B_dst_tensor);
        }
    }
    TensorLayout src_layout({1, k, 1, n}, args.layout_b.dtype);
    TensorLayout filter_layout({m, k, 1, 1}, args.layout_a.dtype);
    TensorLayout dst_layout({1, m, 1, n}, args.layout_c.dtype);

    TensorND src(B_dst_tensor.raw_ptr, src_layout);
    TensorND filter(A_dst_tensor.raw_ptr, filter_layout);
    TensorND z(nullptr, TensorLayout(src_layout.dtype));
    TensorND bias(nullptr, TensorLayout(src_layout.dtype));
    TensorND dst(args.tensor_c.raw_ptr, dst_layout);
    ConvBiasForwardImpl::AlgoBase::ExecArgs conv_exec_args(
            static_cast<ConvBiasForwardImpl*>(conv_opr_ptr.get()), src, filter,
            bias, z, dst, bundle.get_workspace(0));
    m_impl->exec(conv_exec_args);
}
@@ -31,6 +31,7 @@ public:
    class AlgoBase;
    class AlgoCuBlas;
    class AlgoConv1X1CUDNN;
#if CUDA_VERSION >= 10000
    class AlgoUInt4x4x32WMMA;
#endif
@@ -8,7 +8,6 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/relayout/kern.cuh"
#include "src/cuda/relayout/kern_contiguous.cuh"
#include "src/cuda/relayout/kern_transpose.cuh"
@@ -14,7 +14,6 @@
#include "megdnn/tensor_format.h"
#include "test/common/tensor.h"
#include "test/common/timer.h"

using namespace megdnn;
using namespace test;

@@ -51,7 +50,7 @@ namespace {
            ++it0;
            ++it1;
        }
        float error_avg = error_sum / nr_elem;
        if (error_avg > maxerr_avg) {
            return ::testing::AssertionFailure()
@@ -6,13 +6,14 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/cuda/fixture.h" | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/matrix_mul.h" | |||
#include "test/common/benchmarker.h" | |||
#include "src/cuda/utils.h" | |||
#if defined(cuda_check) | |||
@@ -62,7 +63,7 @@ TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32) {
    checker.set_param(param);
    checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
    checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
    checker.set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f));
    checker.exec({{256, 256}, {256, 256}, {256, 256}});
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
@@ -91,12 +92,12 @@ TEST_F(CUDA, BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
    bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
    for (size_t m : {256, 1024, 4096, 10240, 40960}) {
        for (size_t n : {256, 1024, 4096}) {
            for (size_t k : {512, 1024, 2048}) {
                bencher.set_times(400);
                auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
                auto gflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
                printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n", m,
                       k, n, time_in_ms, gflps);
            }
        }
    }
@@ -217,9 +218,7 @@ TEST_F(CUDA, MATRIX_MUL_FLOAT_NAIVE) {
                .set_dtype(0, stype)
                .set_dtype(1, stype)
                .set_dtype(2, dtype)
                .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
                .execs({A, B, {}});
    }
}
@@ -232,6 +231,7 @@ TEST_F(CUDA, MATRIX_MUL) {
    }
    Checker<MatrixMul> checker(handle_cuda());
    using Param = MatrixMul::Param;

    size_t m = 12, n = 16, k = 20;
    bool is_int_available = cuda::is_compute_capability_required(6, 1);

@@ -281,7 +281,7 @@ TEST_F(CUDA, MATRIX_MUL) {
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;

@@ -320,8 +320,7 @@ TEST_F(CUDA, MATRIX_MUL) {
    }
}
TEST_F(CUDA, MATRIX_MUL_CUBLASLT) {
    require_compute_capability(7, 5);
    NormalRNG normal_rng;
    Checker<MatrixMul> checker(handle_cuda());
@@ -333,7 +332,7 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT)
    size_t m = 32, n = 32, k = 32;
    // test Int8 matmul
    {
        DType dtype = dtype::Int32();
        Param param;
        param.transposeA = false;
        param.transposeB = false;
TensorShape A, B; | |||
A = TensorShape{m, k}; | |||
B = TensorShape{k, n}; | |||
checker.set_param(param). | |||
set_dtype(0, stype). | |||
set_dtype(1, stype). | |||
set_dtype(2, dtype). | |||
set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3). | |||
execs({A, B, {}}); | |||
checker.set_param(param) | |||
.set_dtype(0, stype) | |||
.set_dtype(1, stype) | |||
.set_dtype(2, dtype) | |||
.set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3) | |||
.execs({A, B, {}}); | |||
} | |||
    // test floating-point matmul
    for (DType dtype :
         std::array<DType, 2>{{dtype::Float32(), dtype::Float16()}}) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
@@ -365,17 +364,17 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 8e-3)
                    .execs({A, B, {}});
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
@@ -418,7 +417,7 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
    size_t m = 12, n = 16, k = 20;
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
    using Param = MatrixMul::Param;
@@ -426,7 +425,7 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
    DType stype = dtype::Float32();
    DType dtype = dtype::Float32();
    TensorShape A, B;
    param.transposeA = param.transposeB = 1;
    if (param.transposeA)
        A = TensorShape{k, m};
    else
@@ -435,43 +434,43 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
        B = TensorShape{n, k};
    else
        B = TensorShape{k, n};
    checker.set_param(param)
            .set_dtype(0, stype)
            .set_dtype(1, stype)
            .set_dtype(2, dtype)
            .set_epsilon(dtype == dtype::Float16() ? 5e-1 : 5e-2)
            .execs({A, B, {}});
}
TEST_F(CUDA, MATRIX_MUL_CUBLASLT_INT8) {
    require_compute_capability(7, 5);
    NormalRNG normal_rng;
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_rng(0, &normal_rng)
            .set_rng(1, &normal_rng)
            .set_before_exec_callback(
                    AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
    using Param = MatrixMul::Param;
    // size_t m = 32, n = 32, k = 32;
    // test Int8 matmul
    for (size_t m = 8; m <= 64; m += 4)
        for (size_t n = 8; n <= 64; n += 4)
            for (size_t k = 8; k <= 64; k += 4) {
                DType dtype = dtype::Int32();
                Param param;
                param.transposeA = false;
                param.transposeB = false;
                DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
                TensorShape A, B;
                A = TensorShape{m, k};
                B = TensorShape{k, n};
                checker.set_param(param)
                        .set_dtype(0, stype)
                        .set_dtype(1, stype)
                        .set_dtype(2, dtype)
                        .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
                        .execs({A, B, {}});
            }
}
TEST_F(CUDA, MATRIX_MUL_CUBLASLT_F32) {
    require_compute_capability(7, 5);

@@ -501,6 +500,94 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_F32) {
            .set_dtype(2, dtype)
            .execs({A, B, {}});
}
TEST_F(CUDA, MATRIX_MUL_CUDNN_F32_uncont) {
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
    using Param = MatrixMul::Param;
    size_t m = 100, n = 100, k = 100;
    Param param;
    param.transposeA = 1;
    param.transposeB = 1;
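    // A row stride of 128 on the 100x100 operands makes A and B
    // non-contiguous.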
    TensorLayout A{{m, k}, {128, 1}, dtype::Float32()},
            B{{k, n}, {128, 1}, dtype::Float32()}, C{{m, n}, dtype::Float32()};
    DType stype = dtype::Float32();
    DType dtype = dtype::Float32();
    checker.set_param(param)
            .set_dtype(0, stype)
            .set_dtype(1, stype)
            .set_dtype(2, dtype)
            .execl({A, B, C});
}
TEST_F(CUDA, MATRIX_MUL_CUDNN_F32) {
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
    using Param = MatrixMul::Param;
    for (size_t m = 8; m <= 64; m += 4) {
        for (size_t n = 8; n <= 64; n += 4) {
            for (size_t k = 8; k <= 64; k += 4) {
                for (unsigned mask = 0; mask < 4; ++mask) {
                    Param param;
                    param.transposeA = mask & 1;
                    param.transposeB = mask & 2;
                    DType stype = dtype::Float32();
                    DType dtype = dtype::Float32();
                    TensorShape A, B;
                    if (param.transposeA)
                        A = TensorShape{k, m};
                    else
                        A = TensorShape{m, k};
                    if (param.transposeB)
                        B = TensorShape{n, k};
                    else
                        B = TensorShape{k, n};
                    checker.set_param(param)
                            .set_dtype(0, stype)
                            .set_dtype(1, stype)
                            .set_dtype(2, dtype)
                            .execs({A, B, {}});
                }
            }
        }
    }
}
TEST_F(CUDA, MATRIX_MUL_CUDNN_F16) {
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
    using Param = MatrixMul::Param;
    for (size_t m = 8; m <= 64; m += 4) {
        for (size_t n = 8; n <= 64; n += 4) {
            for (size_t k = 8; k <= 64; k += 4) {
                for (unsigned mask = 0; mask < 4; ++mask) {
                    Param param;
                    param.transposeA = mask & 1;
                    param.transposeB = mask & 2;
                    DType stype = dtype::Float16();
                    DType dtype = dtype::Float16();
                    TensorShape A, B;
                    if (param.transposeA)
                        A = TensorShape{k, m};
                    else
                        A = TensorShape{m, k};
                    if (param.transposeB)
                        B = TensorShape{n, k};
                    else
                        B = TensorShape{k, n};
                    checker.set_param(param)
                            .set_dtype(0, stype)
                            .set_dtype(1, stype)
                            .set_dtype(2, dtype)
                            .set_epsilon(6e-2)
                            .execs({A, B, {}});
                }
            }
        }
    }
}

} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen |