|
@@ -183,6 +183,38 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
return imperative::apply(op, converted); |
|
|
return imperative::apply(op, converted); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// differ from Convolution, ConvolutionBackwardData is used in both |
|
|
|
|
|
// functional.conv_transpose2d and quantize.conv_transpose2d |
|
|
|
|
|
ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
|
|
|
auto&& conv_op = const_cast<ConvolutionBackwardData&>( |
|
|
|
|
|
op.cast_final_safe<ConvolutionBackwardData>()); |
|
|
|
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
|
|
|
|
|
|
|
|
|
if (is_quantized_dtype(dtypes[0]) && is_quantized_dtype(dtypes[1])) { |
|
|
|
|
|
return imperative::apply(op, inputs); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
mgb::DType target_dtype; |
|
|
|
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|
|
|
|
conv_op.compute_mode = ConvolutionBackwardData::ComputeMode::FLOAT32; |
|
|
|
|
|
target_dtype = DTypePromoteCfg::amp_low_prec_dtype; |
|
|
|
|
|
} else { |
|
|
|
|
|
target_dtype = get_promoted_dtype(dtypes); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ValueRefList converted(inputs.size()); |
|
|
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
|
|
if (dtypes[i] != target_dtype) { |
|
|
|
|
|
converted[i] = imperative::apply( |
|
|
|
|
|
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; |
|
|
|
|
|
} else { |
|
|
|
|
|
converted[i] = inputs[i]; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return imperative::apply(op, converted); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|
mgb_assert(inputs.size() > 0); |
|
|
mgb_assert(inputs.size() > 0); |
|
@@ -208,12 +240,32 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
return imperative::apply(op, inputs); |
|
|
return imperative::apply(op, inputs); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
|
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
|
|
|
mgb::DType target_dtype = get_promoted_dtype(dtypes); |
|
|
|
|
|
|
|
|
|
|
|
ValueRefList converted(inputs.size()); |
|
|
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
|
|
if (dtypes[i] != target_dtype) { |
|
|
|
|
|
converted[i] = imperative::apply( |
|
|
|
|
|
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; |
|
|
|
|
|
} else { |
|
|
|
|
|
converted[i] = inputs[i]; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return imperative::apply(op, converted); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
struct DTypePromoteRuleRegistry { |
|
|
struct DTypePromoteRuleRegistry { |
|
|
DTypePromoteRuleRegistry() { |
|
|
DTypePromoteRuleRegistry() { |
|
|
register_dtype_promote_rule<Elemwise>(elemwise_rule); |
|
|
register_dtype_promote_rule<Elemwise>(elemwise_rule); |
|
|
register_dtype_promote_rule<Reduce>(reduce_rule); |
|
|
register_dtype_promote_rule<Reduce>(reduce_rule); |
|
|
register_dtype_promote_rule<Convolution>(convolution_rule); |
|
|
register_dtype_promote_rule<Convolution>(convolution_rule); |
|
|
|
|
|
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); |
|
|
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); |
|
|
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); |
|
|
|
|
|
register_dtype_promote_rule<Convolution3D>(convolution3d_rule); |
|
|
|
|
|
register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule); |
|
|
} |
|
|
} |
|
|
} register_helper; |
|
|
} register_helper; |
|
|
|
|
|
|
|
|