diff --git a/imperative/python/megengine/amp/convert_format.py b/imperative/python/megengine/amp/convert_format.py
index 28af3640..4ea82e23 100644
--- a/imperative/python/megengine/amp/convert_format.py
+++ b/imperative/python/megengine/amp/convert_format.py
@@ -1,10 +1,3 @@
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from copy import deepcopy
 
 from .. import functional as F
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index b8c384d0..c6130639 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -592,7 +592,6 @@ def matmul(
     transpose_a=False,
     transpose_b=False,
     compute_mode="default",
-    format="default",
 ) -> Tensor:
     r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
 
@@ -625,7 +624,7 @@ def matmul(
         array([[10., 13.],
                [28., 40.]], dtype=float32)
     """
-    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
+    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode)
 
 
 def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
diff --git a/imperative/python/test/unit/amp/test_convert_format.py b/imperative/python/test/unit/amp/test_convert_format.py
index d9749a14..d78c231e 100644
--- a/imperative/python/test/unit/amp/test_convert_format.py
+++ b/imperative/python/test/unit/amp/test_convert_format.py
@@ -1,10 +1,3 @@
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
 import pytest
 
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index 83b7549e..287cd65c 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -23,24 +23,42 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
     if (format == target)
         return as(tensor, target);
 
+    auto&& shape = tensor.value().shape().cast<ShapeValue>();
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
-        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+        if (shape.ndim == 5) {
             pattern = {0, 1, 4, 2, 3};
-        } else {
+        } else if (shape.ndim == 4) {
             pattern = {0, 3, 1, 2};
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "Unsupport format conversion for tensor %s(shape=%s) from %s to %s",
+                    tensor.to_string().c_str(), shape.to_string().c_str(),
+                    format.to_string().c_str(), Format(target).to_string().c_str());
         }
     } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
-        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+        if (shape.ndim == 5) {
             pattern = {0, 1, 3, 4, 2};
-        } else {
+        } else if (shape.ndim == 4) {
             pattern = {0, 2, 3, 1};
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "Unsupport format conversion for tensor %s(shape=%s) from %s to %s",
+                    tensor.to_string().c_str(), shape.to_string().c_str(),
+                    format.to_string().c_str(), Format(target).to_string().c_str());
         }
     } else {
         mgb_throw(
-                MegBrainError, "Unsupport format conversion from %s to %s",
+                MegBrainError,
+                "Unsupport format conversion for tensor %s(shape=%s) from %s to %s",
+                tensor.to_string().c_str(), shape.to_string().c_str(),
                 format.to_string().c_str(), Format(target).to_string().c_str());
     }
+    mgb_log_debug(
+            "Change tensor %s from %s to %s", tensor.to_string().c_str(),
+            format.to_string().c_str(), Format(target).to_string().c_str());
     auto output = imperative::apply(
             *Dimshuffle::make(pattern, scope), {tensor.value()})[0];
     return m_value_type.make(output, target);
@@ -380,9 +398,7 @@ inline ValueRefList unify_inputs_format(
     ValueRefList unified_inputs(inputs.size());
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
-        if (inp.format() != dst_fmt &&
-            (inp.value().shape().cast<ShapeValue>().ndim == 4 ||
-             inp.value().shape().cast<ShapeValue>().ndim == 5)) {
+        if (inp.format() != dst_fmt) {
             unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
@@ -396,7 +412,16 @@ ValueRefList elemwise_rule(
         const FormatTransformation& t) {
     FT format = get_inputs_format(inputs, t);
     if (format == FT::NHWC && auto_convert) {
-        auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
+        ValueRefList unified_inputs(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            auto&& inp = inputs[i].cast(t.value_type());
+            if (inp.format() != FT::NHWC && inp.value().is_scalar()) {
+                unified_inputs[i] = t.value_type().make(inp.value(), FT::NHWC);
+            } else {
+                unified_inputs[i] = inputs[i];
+            }
+        }
+        unified_inputs = unify_inputs_format(unified_inputs, FT::NHWC, op.scope(), t);
         return t.wrap_outputs(
                 imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
     }
@@ -410,7 +435,16 @@ ValueRefList concat_rule(
     if (!(format == FT::NHWC && auto_convert)) {
         return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
-    auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
+    ValueRefList unified_inputs(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& inp = inputs[i].cast(t.value_type());
+        if (inp.format() != FT::NHWC && inp.value().is_scalar()) {
+            unified_inputs[i] = t.value_type().make(inp.value(), FT::NHWC);
+        } else {
+            unified_inputs[i] = inputs[i];
+        }
+    }
+    unified_inputs = unify_inputs_format(unified_inputs, FT::NHWC, op.scope(), t);
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
     if (axis == 2 || axis == 3) {
@@ -441,7 +475,7 @@ ValueRefList batchnorm_rule(
         const FormatTransformation& t) {
     auto&& inp_format = inputs[0].cast(t.value_type()).format();
     if (inp_format == FT::NHWC) {
-        auto&& new_param = op.param();
+        auto new_param = op.param();
         new_param.param_dim = BatchNorm::ParamDim::DIM_111C;
         auto new_op = BatchNorm::make(new_param);
         return identity_rule_helper(*new_op, inputs, t);
@@ -454,7 +488,7 @@ ValueRefList adaptive_pooling_rule(
         const FormatTransformation& t) {
     auto&& inp_format = inputs[0].cast(t.value_type()).format();
     if (inp_format == FT::NHWC) {
-        auto&& new_param = op.param();
+        auto new_param = op.param();
         new_param.format = AdaptivePooling::Format::NHWC;
         auto new_op = AdaptivePooling::make(new_param, op.shape);
         return identity_rule_helper(*new_op, inputs, t);
@@ -518,7 +552,7 @@ FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
             const FormatTransformation& t) {                                \
         auto&& inp_format = inputs[0].cast(t.value_type()).format();        \
         if (inp_format == FT::NHWC) {                                       \
-            auto&& new_param = _op.param();                                  \
+            auto new_param = _op.param();                                    \
             new_param.format = Op::Format::NHWC;                            \
             auto new_op = Op::make(new_param);                              \
             return identity_rule_helper(*new_op, inputs, t);                \
@@ -535,7 +569,7 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
             const FormatTransformation& t) {                                \
         auto&& inp_format = inputs[0].cast(t.value_type()).format();        \
         if (inp_format == FT::NHWC) {                                       \
-            auto&& new_param = _op.param();                                  \
+            auto new_param = _op.param();                                    \
             new_param.format = Op::Format::NHWC;                            \
             auto new_op = Op::make(new_param, _op.policy());                \
             return identity_rule_helper(*new_op, inputs, t);                \
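
Not part of the patch: a minimal sketch, assuming a stock MegEngine install, of the user-visible effect of dropping the `format` argument from `functional.matmul`. The input values are chosen only to reproduce the docstring output shown in the math.py hunk above.

import numpy as np
import megengine.functional as F
from megengine import tensor

# Plain matmul keeps working after the signature change.
a = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
b = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
print(F.matmul(a, b).numpy())  # [[10. 13.] [28. 40.]]

# The removed keyword should now be rejected like any unknown argument.
try:
    F.matmul(a, b, format="default")
except TypeError as exc:
    print("format keyword removed:", exc)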