GitOrigin-RevId: 585b2c045a
release-1.6
@@ -6,14 +6,15 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "./algos.h"
-#include "src/cuda/utils.h"
-#include "src/common/algo_base.h"
 #include <cuda.h>
+#include "src/common/algo_base.h"
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/conv_bias/opr_impl.h"
+#include "src/cuda/utils.h"
 #if CUDA_VERSION >= 10010
 #include <cublasLt.h>
 #endif
@@ -52,8 +53,21 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
     }
 #endif
 #endif
     all_algos.push_back(&naive);
+    std::vector<cudnnConvolutionFwdAlgo_t> cudnn_conv_enum;
+    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
+        cudnn_conv_enum.push_back(algo.first);
+    }
+    for (auto&& algo : cudnn_conv_enum) {
+        conv1x1.push_back(AlgoConv1X1CUDNN(algo));
+    }
+    for (size_t i = 0; i < conv1x1.size(); ++i) {
+        all_algos.push_back(&conv1x1[i]);
+    }
     for (auto&& algo : all_algos) {
         m_all_algos_map.emplace(algo->info().desc, algo);
     }
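
Note on the registration above: `all_algos` stores raw `AlgoBase*` pointers into the `conv1x1` member vector, so `conv1x1` must be fully populated before any address is taken; one more `push_back` after that point could reallocate the vector and dangle every stored pointer. A minimal standalone sketch of the constraint (illustrative types, not from the patch):

#include <cassert>
#include <vector>

struct Algo {
    int id;
};

int main() {
    std::vector<Algo> pool;       // plays the role of AlgoPack::conv1x1
    std::vector<Algo*> registry;  // plays the role of AlgoPack::all_algos
    for (int i = 0; i < 8; ++i)
        pool.push_back(Algo{i});  // growing is safe: no addresses taken yet
    for (auto& algo : pool)
        registry.push_back(&algo);  // safe only because pool stops growing here
    // pool.push_back(Algo{8});     // would risk invalidating every pointer above
    assert(registry[3]->id == 3);
    return 0;
}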
@@ -11,19 +11,19 @@
  */
 #pragma once
-#include <cuda.h>
-#include <memory>
-#include <unordered_map>
 #include "megdnn/oprs.h"
-#include "src/common/utils.h"
-#include "src/cuda/matrix_mul/opr_impl.h"
 #include "src/common/algo_base.h"
 #include "src/common/metahelper.h"
+#include <unordered_map>
+#include <cuda.h>
+#include <memory>
+#include "src/common/utils.h"
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/conv_bias/opr_impl.h"
+#include "src/cuda/matrix_mul/opr_impl.h"
 #if CUDA_VERSION >= 10010
 #include <cublasLt.h>
 #endif

 namespace megdnn {
 namespace cuda {
@@ -42,6 +42,7 @@ public:
         CUDA_CUBLASLT,
         CUDA_NAIVE,
         CUDA_BFLOAT16,
+        CUDA_CONV1X1_CUDNN,
 #if CUDA_VERSION >= 9020
         CUDA_FLOAT32_SIMT,
         CUDA_FLOAT32_SIMT_SPLIT_K,
@@ -189,6 +190,38 @@ private:
 };
 #endif

+class MatrixMulForwardImpl::AlgoConv1X1CUDNN final : public AlgoBase {
+public:
+    AlgoConv1X1CUDNN(cudnnConvolutionFwdAlgo_t algo_enum) {
+        m_impl = std::make_unique<ConvBiasForwardImpl::AlgoCUDNNConv>(
+                ConvBiasForwardImpl::AlgoCUDNNConv(algo_enum));
+        std::string algoname(m_impl.get()->name());
+        m_name = "MATMUL_CONV1X1:" + algoname;
+    }
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    void exec(const ExecArgs& args) const override;
+    AlgoAttribute attribute() const override {
+        auto ret = AlgoAttribute::DEFAULT;
+#define cb(attr)                                     \
+    if (m_impl.get()->contain_attribute_all(attr)) { \
+        ret |= attr;                                 \
+    }
+        MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
+#undef cb
+        if (m_impl.get()->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
+            ret |= AlgoAttribute::REPRODUCIBLE;
+        }
+        return ret;
+    }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_CONV1X1_CUDNN)
+
+private:
+    std::unique_ptr<ConvBiasForwardImpl::AlgoCUDNNConv> m_impl;
+    std::string m_name;
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+};
+
 #if CUDA_VERSION >= 9020
 class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
 public:
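
Aside on the attribute plumbing in `AlgoConv1X1CUDNN::attribute()` above: the `cb` macro is stamped out once per attribute enumerated by `MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE`, so the wrapper forwards every inheritable attribute that the wrapped cuDNN algorithm reports. Roughly, and with illustrative attribute names (the authoritative list lives in the macro), the method body expands along these lines:

auto ret = AlgoAttribute::DEFAULT;
// ... one such check per inheritable attribute, for example:
if (m_impl.get()->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
    ret |= AlgoAttribute::REPRODUCIBLE;
}
if (m_impl.get()->contain_attribute_all(AlgoAttribute::NAIVE)) {
    ret |= AlgoAttribute::NAIVE;
}
return ret;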
@@ -244,7 +277,8 @@ public:
     MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
     std::string param() const override {
         std::string ret;
-        // FIXME: algo param compatible with old version, to avoid fastrun cache error
+        // FIXME: algo param compatible with old version, to avoid fastrun cache
+        // error
         struct AlgoParam_ {
             int threadblock_m, threadblock_n, threadblock_k;
             int warp_m, warp_n, warp_k;
@@ -272,7 +306,7 @@ public:
               m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
                               m_algo_param.to_string().c_str())} {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     const char* name() const override { return m_name.c_str(); }
     AlgoAttribute attribute() const override {
@@ -409,6 +443,7 @@ public:
     std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
 #endif
 #endif
+    std::vector<AlgoConv1X1CUDNN> conv1x1;
     std::vector<AlgoBase*> all_algos;

     const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
@@ -0,0 +1,173 @@
#include <cuda.h>
#include "./algos.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/relayout/opr_impl.h"
#include "src/cuda/transpose/opr_impl.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

namespace {
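// Build a ConvBias operator configured as a 1x1, stride-1, zero-padding,
// dense NCHW convolution; with these parameters its output equals the
// requested matrix product (see the layout mapping in is_available below).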
std::unique_ptr<ConvBiasForward> prepare_conv_opr(
        const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
    auto conv_bias_opr_ptr =
            args.opr->handle()->create_operator<ConvBiasForward>();
    auto conv_param_computemode =
            (args.opr->param().compute_mode ==
             param::MatrixMul::ComputeMode::DEFAULT)
                    ? param::Convolution::ComputeMode::DEFAULT
                    : param::Convolution::ComputeMode::FLOAT32;
    conv_bias_opr_ptr->param() = {param::ConvBias::NonlineMode::IDENTITY,
                                  param::Convolution::Mode::CROSS_CORRELATION,
                                  param::Convolution::Sparse::DENSE,
                                  param::Convolution::Format::NCHW,
                                  0,  // pad_h
                                  0,  // pad_w
                                  1,  // stride_h
                                  1,  // stride_w
                                  1,  // dilate_h
                                  1,  // dilate_w
                                  conv_param_computemode};
    return conv_bias_opr_ptr;
}
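
// Recover (m, k, n) from the operand layouts, honoring transposeA/transposeB,
// and check that the reduction dimensions of A and B agree.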
std::tuple<size_t, size_t, size_t> gen_matrixmul_shape(
        const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
    size_t m, k, n;
    if (!args.opr->param().transposeA) {
        m = args.layout_a.shape[0];
        k = args.layout_a.shape[1];
    } else {
        m = args.layout_a.shape[1];
        k = args.layout_a.shape[0];
    }
    if (!args.opr->param().transposeB) {
        megdnn_assert(k == args.layout_b.shape[0]);
        n = args.layout_b.shape[1];
    } else {
        megdnn_assert(k == args.layout_b.shape[1]);
        n = args.layout_b.shape[0];
    }
    return std::tuple<size_t, size_t, size_t>{m, k, n};
}
}  // namespace

bool MatrixMulForwardImpl::AlgoConv1X1CUDNN::is_available(
        const SizeArgs& args) const {
    if (!(args.layout_a.ndim == 2 && args.layout_b.ndim == 2 &&
          args.layout_c.ndim == 2))
        return false;
    auto conv_opr_ptr = prepare_conv_opr(args);
    size_t m, k, n;
    std::tie(m, k, n) = gen_matrixmul_shape(args);
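    // Express C(m, n) = A(m, k) * B(k, n) as a 1x1 convolution in NCHW:
    // B acts as the src image {1, k, 1, n}, A as the filter {m, k, 1, 1},
    // and C as the dst {1, m, 1, n}; bias and z stay empty.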
    TensorLayout src_layout({1, k, 1, n}, args.layout_b.dtype);
    TensorLayout filter_layout({m, k, 1, 1}, args.layout_a.dtype);
    TensorLayout bias_layout(args.layout_a.dtype);
    TensorLayout z_layout(args.layout_a.dtype);
    TensorLayout dst_layout({1, m, 1, n}, args.layout_c.dtype);
    ConvBiasForwardImpl::AlgoBase::SizeArgs conv_size_args(
            static_cast<ConvBiasForwardImpl*>(conv_opr_ptr.get()), src_layout,
            filter_layout, bias_layout, z_layout, dst_layout);
    return m_impl->is_available(conv_size_args);
}

WorkspaceBundle MatrixMulForwardImpl::AlgoConv1X1CUDNN::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    SmallVector<size_t> sizes;
    auto conv_opr_ptr = prepare_conv_opr(args);
    size_t m, k, n;
    std::tie(m, k, n) = gen_matrixmul_shape(args);
    TensorLayout src_layout({1, k, 1, n}, args.layout_b.dtype);
    TensorLayout filter_layout({m, k, 1, 1}, args.layout_a.dtype);
    TensorLayout bias_layout(args.layout_a.dtype);
    TensorLayout z_layout(args.layout_a.dtype);
    TensorLayout dst_layout({1, m, 1, n}, args.layout_c.dtype);
    ConvBiasForwardImpl::AlgoBase::SizeArgs conv_size_args(
            static_cast<ConvBiasForwardImpl*>(conv_opr_ptr.get()), src_layout,
            filter_layout, bias_layout, z_layout, dst_layout);
    sizes.push_back(m_impl->get_workspace_in_bytes(conv_size_args));
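    // Any further bundle slots hold contiguous transposed copies of A and/or
    // B when the corresponding transpose flag is set; exec() fills them.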
    auto get_trans_layout = [](const TensorLayout& ly) {
        size_t m = ly[0], n = ly[1];
        TensorLayout trans{{n, m}, ly.dtype};
        return trans;
    };
    if (args.opr->param().transposeA) {
        sizes.push_back(get_trans_layout(args.layout_a).span().dist_byte());
    }
    if (args.opr->param().transposeB) {
        sizes.push_back(get_trans_layout(args.layout_b).span().dist_byte());
    }
    return {ptr, std::move(sizes)};
}

size_t MatrixMulForwardImpl::AlgoConv1X1CUDNN::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void MatrixMulForwardImpl::AlgoConv1X1CUDNN::exec(const ExecArgs& args) const {
    SizeArgs size_args(args.opr, args.layout_a, args.layout_b, args.layout_c);
    auto conv_opr_ptr = prepare_conv_opr(size_args);
    size_t m, k, n;
    std::tie(m, k, n) = gen_matrixmul_shape(size_args);
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, size_args);
    auto A_dst_tensor = args.tensor_a;
    auto B_dst_tensor = args.tensor_b;
    if (args.opr->param().transposeA || args.opr->param().transposeB) {
        auto trans = args.opr->handle()->create_operator<RelayoutForward>();
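        // Materialize each requested transpose: wrap the source buffer in a
        // stride-swapped view of the transposed shape, then relayout it into
        // a contiguous workspace buffer the convolution can consume.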
        auto trans_tensor = [&](size_t workspace_pos,
                                const TensorND& ori_tensor,
                                TensorND& dst_tensor) {
            TensorLayout dst_layout(
                    {ori_tensor.layout.shape[1], ori_tensor.layout.shape[0]},
                    ori_tensor.layout.dtype);
            dst_tensor = TensorND(bundle.get(workspace_pos), dst_layout);
            TensorND src_tensor(ori_tensor.raw_ptr, dst_layout);
            src_tensor.layout.stride[0] = ori_tensor.layout.stride[1];
            src_tensor.layout.stride[1] = ori_tensor.layout.stride[0];
            trans->exec(src_tensor, dst_tensor, args.opr->handle());
        };
        if (args.opr->param().transposeA) {
            trans_tensor(1, args.tensor_a, A_dst_tensor);
        }
        if (args.opr->param().transposeB) {
            trans_tensor(bundle.nr_workspace() - 1, args.tensor_b,
                         B_dst_tensor);
        }
    }
    TensorLayout src_layout({1, k, 1, n}, args.layout_b.dtype);
    TensorLayout filter_layout({m, k, 1, 1}, args.layout_a.dtype);
    TensorLayout dst_layout({1, m, 1, n}, args.layout_c.dtype);
    TensorND src(B_dst_tensor.raw_ptr, src_layout);
    TensorND filter(A_dst_tensor.raw_ptr, filter_layout);
    TensorND z(nullptr, TensorLayout(src_layout.dtype));
    TensorND bias(nullptr, TensorLayout(src_layout.dtype));
    TensorND dst(args.tensor_c.raw_ptr, dst_layout);
    ConvBiasForwardImpl::AlgoBase::ExecArgs conv_exec_args(
            static_cast<ConvBiasForwardImpl*>(conv_opr_ptr.get()), src, filter,
            bias, z, dst, bundle.get_workspace(0));
    m_impl->exec(conv_exec_args);
}
@@ -31,6 +31,7 @@ public:
     class AlgoBase;
     class AlgoCuBlas;
+    class AlgoConv1X1CUDNN;
 #if CUDA_VERSION >= 10000
     class AlgoUInt4x4x32WMMA;
 #endif
@@ -8,7 +8,6 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #include "src/cuda/relayout/kern.cuh"
 #include "src/cuda/relayout/kern_contiguous.cuh"
 #include "src/cuda/relayout/kern_transpose.cuh"
@@ -14,7 +14,6 @@
 #include "megdnn/tensor_format.h"
 #include "test/common/tensor.h"
 #include "test/common/timer.h"

 using namespace megdnn;
 using namespace test;
@@ -51,7 +50,7 @@ namespace {
             ++ it0;
             ++ it1;
         }
         float error_avg = error_sum / nr_elem;
         if (error_avg > maxerr_avg) {
             return ::testing::AssertionFailure()
@@ -6,13 +6,14 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "test/cuda/fixture.h"

+#include "test/common/benchmarker.h"
 #include "test/common/checker.h"
 #include "test/common/matrix_mul.h"
-#include "test/common/benchmarker.h"
 #include "src/cuda/utils.h"

 #if defined(cuda_check)
@@ -62,7 +63,7 @@ TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32) {
     checker.set_param(param);
     checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
     checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
-    checker.set_dtype(2, dtype::QuantizedS32(1.3f*1.3f));
+    checker.set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f));
     checker.exec({{256, 256}, {256, 256}, {256, 256}});
     auto args = matrix_mul::get_matmul_args();
     for (auto arg : args) {
@@ -91,12 +92,12 @@ TEST_F(CUDA, BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
     bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
     for (size_t m : {256, 1024, 4096, 10240, 40960}) {
         for (size_t n : {256, 1024, 4096}) {
-            for (size_t k :{512, 1024, 2048}) {
+            for (size_t k : {512, 1024, 2048}) {
                 bencher.set_times(400);
                 auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
                 auto gflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
-                printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n",
-                       m, k, n, time_in_ms, gflps);
+                printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n", m,
+                       k, n, time_in_ms, gflps);
             }
         }
     }
@@ -217,9 +218,7 @@ TEST_F(CUDA, MATRIX_MUL_FLOAT_NAIVE) {
                 .set_dtype(0, stype)
                 .set_dtype(1, stype)
                 .set_dtype(2, dtype)
-                .set_epsilon(dtype == dtype::Float16()
-                                     ? 5e-2
-                                     : 5e-3)
+                .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
                 .execs({A, B, {}});
     }
 }
@@ -232,6 +231,7 @@ TEST_F(CUDA, MATRIX_MUL) {
     }
     Checker<MatrixMul> checker(handle_cuda());
     using Param = MatrixMul::Param;
     size_t m = 12, n = 16, k = 20;
     bool is_int_available = cuda::is_compute_capability_required(6, 1);
@@ -281,7 +281,7 @@ TEST_F(CUDA, MATRIX_MUL) {
     // general tests
     auto args = matrix_mul::get_matmul_args();
-    for (auto arg: args) {
+    for (auto arg : args) {
         auto m = arg.m, n = arg.n, k = arg.k;
         auto mask = arg.mask;
         Param param;
@@ -320,8 +320,7 @@ TEST_F(CUDA, MATRIX_MUL) {
     }
 }

-TEST_F(CUDA, MATRIX_MUL_CUBLASLT)
-{
+TEST_F(CUDA, MATRIX_MUL_CUBLASLT) {
     require_compute_capability(7, 5);
     NormalRNG normal_rng;
     Checker<MatrixMul> checker(handle_cuda());
@@ -333,7 +332,7 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT) {
     size_t m = 32, n = 32, k = 32;
     // test Int8 matmul
     {
-        DType dtype=dtype::Int32();
+        DType dtype = dtype::Int32();
         Param param;
         param.transposeA = false;
         param.transposeB = false;
@@ -341,16 +340,16 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT) {
         TensorShape A, B;
         A = TensorShape{m, k};
         B = TensorShape{k, n};
-        checker.set_param(param).
-                set_dtype(0, stype).
-                set_dtype(1, stype).
-                set_dtype(2, dtype).
-                set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3).
-                execs({A, B, {}});
+        checker.set_param(param)
+                .set_dtype(0, stype)
+                .set_dtype(1, stype)
+                .set_dtype(2, dtype)
+                .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
+                .execs({A, B, {}});
     }
     // test float-point matmul
-    for (DType dtype: std::array<DType, 2>{
-            {dtype::Float32(), dtype::Float16()}}) {
+    for (DType dtype :
+         std::array<DType, 2>{{dtype::Float32(), dtype::Float16()}}) {
         for (unsigned mask = 0; mask < 4; ++mask) {
             Param param;
             param.transposeA = mask & 1;
@@ -365,17 +364,17 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT) {
                 B = TensorShape{n, k};
             else
                 B = TensorShape{k, n};
-            checker.set_param(param).
-                    set_dtype(0, stype).
-                    set_dtype(1, stype).
-                    set_dtype(2, dtype).
-                    set_epsilon(dtype == dtype::Float16() ? 5e-2 : 8e-3).
-                    execs({A, B, {}});
+            checker.set_param(param)
+                    .set_dtype(0, stype)
+                    .set_dtype(1, stype)
+                    .set_dtype(2, dtype)
+                    .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 8e-3)
+                    .execs({A, B, {}});
         }
     }
     // general tests
     auto args = matrix_mul::get_matmul_args();
-    for (auto arg: args) {
+    for (auto arg : args) {
         auto m = arg.m, n = arg.n, k = arg.k;
         auto mask = arg.mask;
         Param param;
@@ -418,7 +417,7 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
     size_t m = 12, n = 16, k = 20;
     Checker<MatrixMul> checker(handle_cuda());
     checker.set_before_exec_callback(
-        AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
+            AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
     using Param = MatrixMul::Param;
@@ -426,7 +425,7 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
     DType stype = dtype::Float32();
     DType dtype = dtype::Float32();
     TensorShape A, B;
-    param.transposeA=param.transposeB=1;
+    param.transposeA = param.transposeB = 1;
     if (param.transposeA)
         A = TensorShape{k, m};
     else
@@ -435,43 +434,43 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
         B = TensorShape{n, k};
     else
         B = TensorShape{k, n};
-    checker.set_param(param).
-            set_dtype(0, stype).
-            set_dtype(1, stype).
-            set_dtype(2, dtype).
-            set_epsilon(dtype == dtype::Float16() ? 5e-1 : 5e-2).
-            execs({A, B, {}});
+    checker.set_param(param)
+            .set_dtype(0, stype)
+            .set_dtype(1, stype)
+            .set_dtype(2, dtype)
+            .set_epsilon(dtype == dtype::Float16() ? 5e-1 : 5e-2)
+            .execs({A, B, {}});
 }

 TEST_F(CUDA, MATRIX_MUL_CUBLASLT_INT8) {
     require_compute_capability(7, 5);
     NormalRNG normal_rng;
     Checker<MatrixMul> checker(handle_cuda());
     checker.set_rng(0, &normal_rng)
-           .set_rng(1, &normal_rng)
-           .set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
+            .set_rng(1, &normal_rng)
+            .set_before_exec_callback(
+                    AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
     using Param = MatrixMul::Param;
-    //size_t m = 32, n = 32, k = 32;
+    // size_t m = 32, n = 32, k = 32;
     // test Int8 matmul
-    for (size_t m=8; m<=64; m+=4)
-    for (size_t n=8; n<=64; n+=4)
-    for (size_t k=8; k<=64; k+=4)
-    {
-        DType dtype=dtype::Int32();
-        Param param;
-        param.transposeA = false;
-        param.transposeB = false;
-        DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
-        TensorShape A, B;
-        A = TensorShape{m, k};
-        B = TensorShape{k, n};
-        checker.set_param(param).
-                set_dtype(0, stype).
-                set_dtype(1, stype).
-                set_dtype(2, dtype).
-                set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3).
-                execs({A, B, {}});
-    }
+    for (size_t m = 8; m <= 64; m += 4)
+        for (size_t n = 8; n <= 64; n += 4)
+            for (size_t k = 8; k <= 64; k += 4) {
+                DType dtype = dtype::Int32();
+                Param param;
+                param.transposeA = false;
+                param.transposeB = false;
+                DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
+                TensorShape A, B;
+                A = TensorShape{m, k};
+                B = TensorShape{k, n};
+                checker.set_param(param)
+                        .set_dtype(0, stype)
+                        .set_dtype(1, stype)
+                        .set_dtype(2, dtype)
+                        .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
+                        .execs({A, B, {}});
+            }
 }

 TEST_F(CUDA, MATRIX_MUL_CUBLASLT_F32) {
     require_compute_capability(7, 5);
@@ -501,6 +500,94 @@ TEST_F(CUDA, MATRIX_MUL_CUBLASLT_F32) {
             .set_dtype(2, dtype)
             .execs({A, B, {}});
     }
 }

+TEST_F(CUDA, MATRIX_MUL_CUDNN_F32_uncont) {
+    Checker<MatrixMul> checker(handle_cuda());
+    checker.set_before_exec_callback(
+            AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
+    using Param = MatrixMul::Param;
+    size_t m = 100, n = 100, k = 100;
+    Param param;
+    param.transposeA = 1;
+    param.transposeB = 1;
+    TensorLayout A{{m, k}, {128, 1}, dtype::Float32()},
+            B{{k, n}, {128, 1}, dtype::Float32()}, C{{m, n}, dtype::Float32()};
+    DType stype = dtype::Float32();
+    DType dtype = dtype::Float32();
+    checker.set_param(param)
+            .set_dtype(0, stype)
+            .set_dtype(1, stype)
+            .set_dtype(2, dtype)
+            .execl({A, B, {}});
+}
+
+TEST_F(CUDA, MATRIX_MUL_CUDNN_F32) {
+    Checker<MatrixMul> checker(handle_cuda());
+    checker.set_before_exec_callback(
+            AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
+    using Param = MatrixMul::Param;
+    for (size_t m = 8; m <= 64; m += 4) {
+        for (size_t n = 8; n <= 64; n += 4) {
+            for (size_t k = 8; k <= 64; k += 4) {
+                for (unsigned mask = 0; mask < 4; ++mask) {
+                    Param param;
+                    param.transposeA = mask & 1;
+                    param.transposeB = mask & 2;
+                    DType stype = dtype::Float32();
+                    DType dtype = dtype::Float32();
+                    TensorShape A, B;
+                    if (param.transposeA)
+                        A = TensorShape{k, m};
+                    else
+                        A = TensorShape{m, k};
+                    if (param.transposeB)
+                        B = TensorShape{n, k};
+                    else
+                        B = TensorShape{k, n};
+                    checker.set_param(param)
+                            .set_dtype(0, stype)
+                            .set_dtype(1, stype)
+                            .set_dtype(2, dtype)
+                            .execs({A, B, {}});
+                }
+            }
+        }
+    }
+}
+
+TEST_F(CUDA, MATRIX_MUL_CUDNN_F16) {
+    Checker<MatrixMul> checker(handle_cuda());
+    checker.set_before_exec_callback(
+            AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
+    using Param = MatrixMul::Param;
+    for (size_t m = 8; m <= 64; m += 4) {
+        for (size_t n = 8; n <= 64; n += 4) {
+            for (size_t k = 8; k <= 64; k += 4) {
+                for (unsigned mask = 0; mask < 4; ++mask) {
+                    Param param;
+                    param.transposeA = mask & 1;
+                    param.transposeB = mask & 2;
+                    DType stype = dtype::Float16();
+                    DType dtype = dtype::Float16();
+                    TensorShape A, B;
+                    if (param.transposeA)
+                        A = TensorShape{k, m};
+                    else
+                        A = TensorShape{m, k};
+                    if (param.transposeB)
+                        B = TensorShape{n, k};
+                    else
+                        B = TensorShape{k, n};
+                    checker.set_param(param)
+                            .set_dtype(0, stype)
+                            .set_dtype(1, stype)
+                            .set_dtype(2, dtype)
+                            .set_epsilon(6e-2)
+                            .execs({A, B, {}});
+                }
+            }
+        }
+    }
+}
+
 } // namespace test
 } // namespace megdnn

 // vim: syntax=cpp.doxygen