diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 17028a5e..646064af 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -555,6 +555,7 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
     replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
     replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
     replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
+    replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
     return ret;
     MIDOUT_E
 }
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index a14e8fdc..161d0a20 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2244,6 +2244,67 @@ TEST(TestEnableTensorCore, Pooling) {
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
 }
 
+TEST(TestEnableTensorCore, BatchConvBias) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
+    auto sm_ver = prop.major * 10 + prop.minor;
+    if (sm_ver < 75) {
+        printf("This testcase is ignored due to insufficient cuda cap (got: %d, "
+               "expected: %d)\n",
+               sm_ver, 75);
+        return;
+    }
+
+    HostTensorGenerator<dtype::Int8> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp, const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp, const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name),
+                dtype);
+    };
+
+    auto inp = mkvar("inp", {32, 24, 24, 24, 4}, dtype::QuantizedS8(1.1f)),
+         flt = mkcvar("flt", {32, 96, 24, 1, 1, 4}, dtype::QuantizedS8(1.2f)),
+         bias = mkcvar("bias", {1, 24, 1, 1, 4}, dtype::QuantizedS32{1.1f * 1.2f});
+    opr::BatchConvBias::Param param;
+    param.format = opr::BatchConvBias::Param::Format::NCHW4;
+    param.stride_h = param.stride_w = 1;
+    param.pad_h = param.pad_w = 0;
+
+    auto y = opr::BatchConvBias::make(
+            inp, flt, bias, param, {}, OperatorNodeConfig{dtype::QuantizedS8{1.3f}});
+    y = opr::TypeCvt::make(y, dtype::Float32());
+
+    SymbolVar y_opt;
+    SymbolVar y_no_tc;
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_fuse_conv_bias_nonlinearity().enable_nchw32();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    }
+    ASSERT_EQ(
+            opr::BatchConvBias::Param::Format::NCHW4,
+            find_opr<opr::BatchConvBias>(y_opt).param().format);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_fuse_conv_bias_nonlinearity();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_no_tc);
+    }
+    HostTensorND host_y, host_y_opt;
+    auto func = graph->compile(
+            {make_callback_copy(y_no_tc, host_y),
+             make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
+}
+
 TEST(TestGoptInference, EnableTensorCore) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");