|
|
@@ -731,9 +731,10 @@ std::vector<TestArg> get_int8_chwn4_tensorcore_args(size_t kernel_size) { |
|
|
|
void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, |
|
|
|
DType dst_dtype, Handle* handle, const char* algo, |
|
|
|
param::ConvBias::Format format, |
|
|
|
const std::vector<TestArg>& args, bool fuse_z) { |
|
|
|
const std::vector<TestArg>& args, bool fuse_z, |
|
|
|
bool stable_test) { |
|
|
|
megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); |
|
|
|
Checker<ConvBiasForward> checker(handle); |
|
|
|
Checker<ConvBiasForward> checker(handle, !stable_test); |
|
|
|
if (algo) { |
|
|
|
checker.set_before_exec_callback( |
|
|
|
ConvBiasAlgoChecker<ConvBiasForward>(algo)); |
|
|
@@ -823,6 +824,10 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, |
|
|
|
.set_rng(1, rng.get()) |
|
|
|
.set_rng(2, bias_rng.get()) |
|
|
|
.set_rng(3, rng.get()); |
|
|
|
if (stable_test) { |
|
|
|
checker.set_stable_check(true); |
|
|
|
checker.set_no_naive_check(true); |
|
|
|
} |
|
|
|
if (args.empty()) { |
|
|
|
std::vector<TestArg> default_args; |
|
|
|
if (format == Format::NCHW4) { |
|
|
|