|
|
@@ -238,6 +238,25 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_CUDNN) { |
|
|
|
if (cuda::is_compute_capability_required(7, 0)) |
|
|
|
return; |
|
|
|
using namespace convolution; |
|
|
|
Checker<ConvolutionBackwardData> checker(handle_cuda()); |
|
|
|
checker.set_before_exec_callback(AlgoChecker<ConvolutionBackwardData>( |
|
|
|
"CUDNN_CONVOLUTION")); |
|
|
|
//! noncontiguous case |
|
|
|
{ |
|
|
|
param::Convolution param; |
|
|
|
param.pad_h = param.pad_w = 1; |
|
|
|
checker.set_param(param).execl(TensorLayoutArray{ |
|
|
|
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()}, |
|
|
|
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()}, |
|
|
|
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()}, |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) { |
|
|
|
using namespace convolution; |
|
|
|
std::vector<TestArg> args = get_args_cuda_conv_bwd_data(); |
|
|
@@ -265,6 +284,16 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) { |
|
|
|
.set_param(arg.param) |
|
|
|
.exec(TensorLayoutArray{filter, dst, src}); |
|
|
|
} |
|
|
|
//! noncontiguous case |
|
|
|
{ |
|
|
|
param::Convolution param; |
|
|
|
param.pad_h = param.pad_w = 1; |
|
|
|
checker.set_param(param).execl(TensorLayoutArray{ |
|
|
|
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()}, |
|
|
|
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()}, |
|
|
|
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()}, |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW4_DP4A) { |
|
|
@@ -355,6 +384,16 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) { |
|
|
|
} |
|
|
|
checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec( |
|
|
|
TensorLayoutArray{filter, dst, src}); |
|
|
|
//! noncontiguous case |
|
|
|
{ |
|
|
|
param::Convolution param; |
|
|
|
param.pad_h = param.pad_w = 1; |
|
|
|
checker.set_param(param).execl(TensorLayoutArray{ |
|
|
|
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::QuantizedS8{1.3f}}, |
|
|
|
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::QuantizedS8{1.2f}}, |
|
|
|
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::QuantizedS8{1.2f}} |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|