|
@@ -26,6 +26,7 @@ |
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
#include "megbrain/opr/imgproc.h" |
|
|
#include "megbrain/opr/imgproc.h" |
|
|
#include "megbrain/opr/nn_int.h" |
|
|
#include "megbrain/opr/nn_int.h" |
|
|
|
|
|
#include "megbrain/opr/tensor_gen.h" |
|
|
|
|
|
|
|
|
#include "megdnn/tensor_format.h" |
|
|
#include "megdnn/tensor_format.h" |
|
|
|
|
|
|
|
@@ -741,6 +742,19 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
return opr; |
|
|
return opr; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto replace_lsp_opr = [](OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
|
|
mgb_assert(opr->same_type<opr::Linspace>()); |
|
|
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
auto& lsp_opr = opr->cast_final_safe<opr::Linspace>(); |
|
|
|
|
|
if (lsp_opr.output(0)->dtype() != dtype::Float16()) { |
|
|
|
|
|
auto cvt_var = |
|
|
|
|
|
opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {}); |
|
|
|
|
|
return cvt_var.node()->owner_opr(); |
|
|
|
|
|
} |
|
|
|
|
|
return opr; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
@@ -778,6 +792,29 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
return new_matmul_opr.node()->owner_opr(); |
|
|
return new_matmul_opr.node()->owner_opr(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto replace_batched_matmul_opr = [use_f32_comp]( |
|
|
|
|
|
OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>(); |
|
|
|
|
|
auto new_param = matmul_opr.param(); |
|
|
|
|
|
if (use_f32_comp) { |
|
|
|
|
|
new_param.compute_mode = |
|
|
|
|
|
megdnn::param::MatrixMul::ComputeMode::FLOAT32; |
|
|
|
|
|
} |
|
|
|
|
|
mgb_assert(new_inp[0]->dtype() == dtype::Float16(), |
|
|
|
|
|
"inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(), |
|
|
|
|
|
new_inp[0]->name().c_str(), |
|
|
|
|
|
new_inp[0]->owner_opr()->name().c_str()); |
|
|
|
|
|
mgb_assert(new_inp[1]->dtype() == dtype::Float16(), |
|
|
|
|
|
"inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(), |
|
|
|
|
|
new_inp[1]->name().c_str(), |
|
|
|
|
|
new_inp[1]->owner_opr()->name().c_str()); |
|
|
|
|
|
auto new_matmul_opr = opr::BatchedMatrixMul::make( |
|
|
|
|
|
new_inp[0], new_inp[1], new_param, matmul_opr.config()); |
|
|
|
|
|
return new_matmul_opr.node()->owner_opr(); |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
|
auto& reduce_opr = opr->cast_final_safe<opr::Reduce>(); |
|
|
auto& reduce_opr = opr->cast_final_safe<opr::Reduce>(); |
|
@@ -871,6 +908,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ |
|
|
ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ |
|
|
VarReplaceCheckFlag::CHECK_DTYPE); |
|
|
VarReplaceCheckFlag::CHECK_DTYPE); |
|
|
auto&& replace_func = ret->m_opr_replace_func; |
|
|
auto&& replace_func = ret->m_opr_replace_func; |
|
|
|
|
|
replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr; |
|
|
replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr; |
|
|
replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr; |
|
|
replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr; |
|
|
replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr; |
|
|
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; |
|
|
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; |
|
@@ -880,6 +918,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr; |
|
|
replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr; |
|
|
replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr; |
|
|
replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr; |
|
|
replace_func[opr::Remap::typeinfo()] = replace_remap_opr; |
|
|
replace_func[opr::Remap::typeinfo()] = replace_remap_opr; |
|
|
|
|
|
replace_func[opr::BatchedMatrixMul::typeinfo()] = |
|
|
|
|
|
replace_batched_matmul_opr; |
|
|
return ret; |
|
|
return ret; |
|
|
#endif |
|
|
#endif |
|
|
} |
|
|
} |
|
|