|
|
@@ -355,6 +355,33 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& |
|
|
|
return format; |
|
|
|
} |
|
|
|
|
|
|
|
// Returns a list, parallel to `inputs`, in which every eligible input has been
// converted to NHWC.
//
// An input is converted when its format is not already NHWC and its shape has
// exactly four dimensions: it is passed through `t.as(inp, FT::NCHW)` and the
// result is converted with `t.to(..., FT::NHWC, scope)`. Every other input is
// forwarded unchanged, so the result always has `inputs.size()` entries in the
// original order.
//
// `scope` is only read and forwarded to `t.to`, so it is taken by const
// reference to avoid copying the string on every call.
inline ValueRefList unify_nhwc_inputs(
        Span<ValueRef>& inputs, const std::string& scope,
        const FormatTransformation& t) {
    ValueRefList unified_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& inp = inputs[i].cast(t.value_type());
        // The format check comes first so the shape is only queried for
        // inputs that are not already NHWC.
        if (inp.format() != FT::NHWC &&
            inp.value().shape().cast<ShapeValue>().ndim == 4) {
            unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope);
        } else {
            unified_inputs[i] = inputs[i];
        }
    }
    return unified_inputs;
}
|
|
|
|
|
|
|
// Format rule for Elemwise.
//
// The common format of `inputs` is obtained from get_inputs_format. Unless
// that format is NHWC and `auto_convert` is set, the op is applied to the
// inputs as given. Otherwise the inputs are first passed through
// unify_nhwc_inputs and the op is applied to the unified list. In both cases
// the outputs are wrapped with the detected format.
ValueRefList elemwise_rule(
        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    FT fmt = get_inputs_format(inputs, t);

    // Fast path: nothing to unify, apply directly on the original inputs.
    if (fmt != FT::NHWC || !auto_convert) {
        return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), fmt);
    }

    // NHWC with auto conversion enabled: unify the inputs before applying.
    auto nhwc_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
    return t.wrap_outputs(
            imperative::apply(op, t.unwrap_inputs(nhwc_inputs)), fmt);
}
|
|
|
|
|
|
|
ValueRefList concat_rule( |
|
|
|
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
@@ -362,6 +389,7 @@ ValueRefList concat_rule( |
|
|
|
if (!(format == FT::NHWC && auto_convert)) { |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); |
|
|
|
} |
|
|
|
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t); |
|
|
|
// TODO: handle 5D NHWC Tensor from group conv |
|
|
|
auto axis = op.axis; |
|
|
|
if (axis == 2 || axis == 3) { |
|
|
@@ -372,7 +400,7 @@ ValueRefList concat_rule( |
|
|
|
return t.wrap_outputs( |
|
|
|
imperative::apply( |
|
|
|
*Concat::make(axis, op.comp_node, op.scope()), |
|
|
|
t.unwrap_inputs(inputs)), |
|
|
|
t.unwrap_inputs(unified_inputs)), |
|
|
|
format); |
|
|
|
} |
|
|
|
|
|
|
@@ -415,7 +443,6 @@ ValueRefList adaptive_pooling_rule( |
|
|
|
|
|
|
|
// clang-format off |
|
|
|
// Invokes `cb(Op)` once for every op type in this list.
// Elemwise is intentionally not listed: it has its own dedicated
// `elemwise_rule`, which is registered explicitly in FormatRuleRegistry, so
// listing it here as well would cover the same op type twice.
// NOTE(review): the use site of this macro is outside this view; confirm that
// nothing else depends on Elemwise being part of the list.
#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
    cb(CompiledOp)                         \
    cb(SubgraphOp)
|
|
|
|
|
|
@@ -501,6 +528,7 @@ struct FormatRuleRegistry { |
|
|
|
register_format_rule(subtensor_rule<IndexingMultiAxisVec>); |
|
|
|
register_format_rule(setsubtensor_rule<SetSubtensor>); |
|
|
|
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>); |
|
|
|
register_format_rule(elemwise_rule); |
|
|
|
register_format_rule(concat_rule); |
|
|
|
register_format_rule(batchnorm_rule); |
|
|
|
register_format_rule(adaptive_pooling_rule); |
|
|
@@ -515,7 +543,6 @@ struct FormatRuleRegistry { |
|
|
|
|
|
|
|
ValueRefList FormatTransformation::apply_transformation( |
|
|
|
const Operator& op, Span<ValueRef> inputs) { |
|
|
|
// mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str()); |
|
|
|
if (auto* apply_op = op.as<ApplyOp>()) { |
|
|
|
// all inputs should be FormattedTensorValue |
|
|
|
auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); |
|
|
|