Browse Source

feat(imperative/amp): add auto dimshuffle for elemwise and concat

GitOrigin-RevId: 6e3df4e064
release-1.10
Megvii Engine Team 3 years ago
parent
commit
38a9aa9faf
2 changed files with 36 additions and 5 deletions
  1. +6
    -2
      imperative/python/test/unit/core/test_formatted_tensor.py
  2. +30
    -3
      imperative/src/impl/transformations/format.cpp

+ 6
- 2
imperative/python/test/unit/core/test_formatted_tensor.py View File

@@ -193,7 +193,10 @@ def test_typecvt(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None])
def test_elemwise(is_symbolic):
def elemwise(x):
return (x * 2 + x / 2).numpy()
tmp = F.ones((1, 2, 3, 4))
oup = x * tmp + x / 2
assert oup.format == x.format
return oup.numpy()

data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, elemwise, is_symbolic)
@@ -202,7 +205,8 @@ def test_elemwise(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None])
def test_concat(is_symbolic):
def func(x):
rst = F.concat([x / 2, x * 2], axis=1)
tmp = F.ones((1, 2, 3, 4))
rst = F.concat([x / 2, tmp], axis=1)
assert rst.format == x.format
return rst.numpy()



+ 30
- 3
imperative/src/impl/transformations/format.cpp View File

@@ -355,6 +355,33 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation&
return format;
}

inline ValueRefList unify_nhwc_inputs(
Span<ValueRef>& inputs, std::string scope, const FormatTransformation& t) {
ValueRefList unified_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto&& inp = inputs[i].cast(t.value_type());
if (inp.format() != FT::NHWC &&
inp.value().shape().cast<ShapeValue>().ndim == 4) {
unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope);
} else {
unified_inputs[i] = inputs[i];
}
}
return unified_inputs;
}

ValueRefList elemwise_rule(
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
FT format = get_inputs_format(inputs, t);
if (format == FT::NHWC && auto_convert) {
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
return t.wrap_outputs(
imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
}

ValueRefList concat_rule(
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
@@ -362,6 +389,7 @@ ValueRefList concat_rule(
if (!(format == FT::NHWC && auto_convert)) {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
}
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
// TODO: handle 5D NHWC Tensor from group conv
auto axis = op.axis;
if (axis == 2 || axis == 3) {
@@ -372,7 +400,7 @@ ValueRefList concat_rule(
return t.wrap_outputs(
imperative::apply(
*Concat::make(axis, op.comp_node, op.scope()),
t.unwrap_inputs(inputs)),
t.unwrap_inputs(unified_inputs)),
format);
}

@@ -415,7 +443,6 @@ ValueRefList adaptive_pooling_rule(

// clang-format off
#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
cb(Elemwise) \
cb(CompiledOp) \
cb(SubgraphOp)

@@ -501,6 +528,7 @@ struct FormatRuleRegistry {
register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
register_format_rule(setsubtensor_rule<SetSubtensor>);
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(elemwise_rule);
register_format_rule(concat_rule);
register_format_rule(batchnorm_rule);
register_format_rule(adaptive_pooling_rule);
@@ -515,7 +543,6 @@ struct FormatRuleRegistry {

ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
// mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo());


Loading…
Cancel
Save