|
|
@@ -218,17 +218,15 @@ ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
|
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|
|
mgb_assert(inputs.size() > 0); |
|
|
|
|
|
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
|
ValueRefList converted(inputs.size()); |
|
|
|
converted[0] = imperative::apply( |
|
|
|
ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0]; |
|
|
|
|
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
DType idtype = *(inputs[i].dtype()); |
|
|
|
if (idtype != DTypePromoteCfg::amp_high_prec_dtype) { |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
mgb::DType target_dtype = i == 0 ? DTypePromoteCfg::amp_low_prec_dtype |
|
|
|
: DTypePromoteCfg::amp_high_prec_dtype; |
|
|
|
if (dtypes[i] != target_dtype) { |
|
|
|
converted[i] = imperative::apply( |
|
|
|
ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), |
|
|
|
inputs[i])[0]; |
|
|
|
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; |
|
|
|
} else { |
|
|
|
converted[i] = inputs[i]; |
|
|
|
} |
|
|
|