
refactor(imperative): move typecvt code of several ops to c++

GitOrigin-RevId: 4ffaa376c1
tags/v1.9.0
Megvii Engine Team, 3 years ago
commit 729242f9f8
2 changed files with 54 additions and 21 deletions:
  1. imperative/python/megengine/functional/nn.py (+2, -21)
  2. imperative/src/impl/transformations/dtype_promote.cpp (+52, -0)

imperative/python/megengine/functional/nn.py (+2, -21)

@@ -320,12 +320,6 @@ def conv3d(
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
 
-    dtype = dtype_promotion(inp, weight)
-    if inp.dtype != dtype:
-        inp = inp.astype(dtype)
-    if weight.dtype != dtype:
-        weight = weight.astype(dtype)
-
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
         pad_d=pad[D],
@@ -389,15 +383,6 @@ def conv_transpose2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    if amp._enabled:
-        compute_mode = "float32"
-        inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        dtype = dtype_promotion(inp, weight)
-        if inp.dtype != dtype:
-            inp = inp.astype(dtype)
-        if weight.dtype != dtype:
-            weight = weight.astype(dtype)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -418,6 +403,8 @@ def conv_transpose2d(
     )
     (output,) = apply(op, weight, inp)
     if bias is not None:
+        if amp._enabled:
+            bias = cast_tensors(bias)
         output += bias
     return output
 
@@ -591,12 +578,6 @@ def conv_transpose3d(
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
 
-    dtype = dtype_promotion(inp, weight)
-    if inp.dtype != dtype:
-        inp = inp.astype(dtype)
-    if weight.dtype != dtype:
-        weight = weight.astype(dtype)
-
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
         pad_d=pad[D],
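Each of the three deleted Python blocks above follows the same pattern: compute the common dtype of the two operands, then cast whichever operand differs. A minimal standalone sketch of that pattern, using numpy.promote_types as an illustrative stand-in for MegEngine's internal dtype_promotion helper (an assumption for illustration, not the actual helper):

import numpy as np

def promote_pair(inp, weight):
    # Cast both operands to their common dtype before dispatching the op.
    dtype = np.promote_types(inp.dtype, weight.dtype)
    if inp.dtype != dtype:
        inp = inp.astype(dtype)
    if weight.dtype != dtype:
        weight = weight.astype(dtype)
    return inp, weight

a = np.ones((2, 2), dtype="float16")
b = np.ones((2, 2), dtype="float32")
a, b = promote_pair(a, b)
assert a.dtype == b.dtype  # both promoted to float32

Note that for conv_transpose2d the AMP cast of inp and weight also moves to C++ (convolution_backward_rule below); only the bias cast stays in Python, since bias is added after the op is applied.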


imperative/src/impl/transformations/dtype_promote.cpp (+52, -0)

@@ -183,6 +183,38 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, converted);
 }
 
+// Unlike Convolution, ConvolutionBackwardData is used in both
+// functional.conv_transpose2d and quantize.conv_transpose2d
+ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& conv_op = const_cast<ConvolutionBackwardData&>(
+            op.cast_final_safe<ConvolutionBackwardData>());
+    SmallVector<DType> dtypes = get_value_dtypes(inputs);
+
+    if (is_quantized_dtype(dtypes[0]) && is_quantized_dtype(dtypes[1])) {
+        return imperative::apply(op, inputs);
+    }
+
+    mgb::DType target_dtype;
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        conv_op.compute_mode = ConvolutionBackwardData::ComputeMode::FLOAT32;
+        target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
+    } else {
+        target_dtype = get_promoted_dtype(dtypes);
+    }
+
+    ValueRefList converted(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (dtypes[i] != target_dtype) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
         mgb_assert(inputs.size() > 0);
@@ -208,12 +240,32 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, inputs);
 }
 
+ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
+    SmallVector<DType> dtypes = get_value_dtypes(inputs);
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    ValueRefList converted(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (dtypes[i] != target_dtype) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 struct DTypePromoteRuleRegistry {
     DTypePromoteRuleRegistry() {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);
         register_dtype_promote_rule<Convolution>(convolution_rule);
+        register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
         register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
+        register_dtype_promote_rule<Convolution3D>(convolution3d_rule);
+        register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule);
     }
 } register_helper;
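The registry now routes ConvolutionBackwardData, Convolution3D, and Convolution3DBackwardData through the new C++ rules, so promotion happens inside the dispatch rather than in each Python wrapper. A hedged usage sketch of the user-visible behavior (assuming MegEngine v1.9's functional API; kernel support for float16 inputs may vary by platform):

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.random.randn(1, 2, 4, 4, 4).astype("float16"))
weight = mge.tensor(np.random.randn(3, 2, 1, 1, 1).astype("float32"))

# Mixed dtypes: convolution3d_rule promotes inp to float32 before dispatch,
# with no astype calls left in functional/nn.py.
out = F.conv3d(inp, weight)
print(out.dtype)  # float32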



