|
@@ -771,6 +771,29 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
return new_conv_opr.node()->owner_opr(); |
|
|
return new_conv_opr.node()->owner_opr(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto replace_deconv_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>(); |
|
|
|
|
|
auto new_param = deconv_opr.param(); |
|
|
|
|
|
if (use_f32_comp) { |
|
|
|
|
|
new_param.compute_mode = |
|
|
|
|
|
megdnn::param::Convolution::ComputeMode::FLOAT32; |
|
|
|
|
|
} |
|
|
|
|
|
mgb_assert(new_inp[0]->dtype() == dtype::Float16(), |
|
|
|
|
|
"inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(), |
|
|
|
|
|
new_inp[0]->name().c_str(), |
|
|
|
|
|
new_inp[0]->owner_opr()->name().c_str()); |
|
|
|
|
|
mgb_assert(new_inp[1]->dtype() == dtype::Float16(), |
|
|
|
|
|
"inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(), |
|
|
|
|
|
new_inp[1]->name().c_str(), |
|
|
|
|
|
new_inp[1]->owner_opr()->name().c_str()); |
|
|
|
|
|
auto new_deconv_opr = opr::ConvolutionBackwardData::make( |
|
|
|
|
|
new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(), |
|
|
|
|
|
deconv_opr.config()); |
|
|
|
|
|
return new_deconv_opr.node()->owner_opr(); |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_convbias_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
auto replace_convbias_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
|
auto& convbias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
auto& convbias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
@@ -941,6 +964,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr; |
|
|
replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr; |
|
|
replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr; |
|
|
replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr; |
|
|
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; |
|
|
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; |
|
|
|
|
|
replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr; |
|
|
replace_func[opr::ConvBias::typeinfo()] = replace_convbias_opr; |
|
|
replace_func[opr::ConvBias::typeinfo()] = replace_convbias_opr; |
|
|
replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr; |
|
|
replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr; |
|
|
replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr; |
|
|
replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr; |
|
|