|
|
@@ -6,7 +6,8 @@ |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, |
|
|
|
* software distributed under the License is distributed on an |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
|
|
* implied. |
|
|
|
*/ |
|
|
|
#include "megdnn/oprs/nn.h" |
|
|
|
|
|
|
@@ -37,7 +38,7 @@ std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) { |
|
|
|
args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); |
|
|
|
args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); |
|
|
|
args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); |
|
|
|
|
|
|
|
|
|
|
|
args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); |
|
|
|
args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); |
|
|
|
args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); |
|
|
@@ -614,11 +615,8 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_HSWISH) { |
|
|
|
param.stride_h = param.stride_w = 1; |
|
|
|
param.format = param::ConvBias::Format::CHWN4; |
|
|
|
param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; |
|
|
|
checker.set_param(param).execs({{4, 12, 12, 32, 4}, |
|
|
|
{4, 3, 3, 16, 4}, |
|
|
|
{4, 1, 1, 1, 4}, |
|
|
|
{}, |
|
|
|
{}}); |
|
|
|
checker.set_param(param).execs( |
|
|
|
{{4, 12, 12, 32, 4}, {4, 3, 3, 16, 4}, {4, 1, 1, 1, 4}, {}, {}}); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_CHECK_BOUNDS) { |
|
|
@@ -1076,7 +1074,6 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 10020 |
|
|
|
/// \note: we only check several cases and block sizes in megdnn_test, the full |
|
|
|
/// testcases are written in cutlass repository |
|
|
@@ -1234,8 +1231,7 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4) { |
|
|
|
handle_cuda(), get_resnet50_bench_args(64), |
|
|
|
dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f}, |
|
|
|
dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f}, |
|
|
|
"INT8_NCHW4_DOTPROD_IMPLICIT_GEMM", |
|
|
|
param::ConvBias::Format::NCHW4); |
|
|
|
"INT8_NCHW4_DOTPROD_IMPLICIT_GEMM", param::ConvBias::Format::NCHW4); |
|
|
|
} |
|
|
|
#endif |
|
|
|
} // namespace test |
|
|
|