GitOrigin-RevId: 4461f9a0d3
release-1.10
@@ -33,7 +33,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): | |||||
if x.format != "nhwc": | if x.format != "nhwc": | ||||
if inplace: | if inplace: | ||||
# hostvalue should still be valid, so no d2h cost. | # hostvalue should still be valid, so no d2h cost. | ||||
data = x.numpy().transpose(*pattern) | |||||
data = x.numpy() | |||||
# reset will destroy existed backward grad | # reset will destroy existed backward grad | ||||
x[...] = Tensor(data, format="nhwc") | x[...] = Tensor(data, format="nhwc") | ||||
else: | else: | ||||
@@ -38,7 +38,7 @@ def test_basic(): | |||||
def _compare_nchw_nhwc(data, func, is_symbolic=None): | def _compare_nchw_nhwc(data, func, is_symbolic=None): | ||||
x1 = tensor(data) | x1 = tensor(data) | ||||
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||||
x2 = tensor(data, format="nhwc") | |||||
if is_symbolic is not None: | if is_symbolic is not None: | ||||
func = trace(func, symbolic=is_symbolic) | func = trace(func, symbolic=is_symbolic) | ||||
out1 = func(x1) | out1 = func(x1) | ||||
@@ -247,8 +247,8 @@ def test_conv2d(is_symbolic): | |||||
if x.format == "nhwc": | if x.format == "nhwc": | ||||
x = F.conv2d( | x = F.conv2d( | ||||
x, | x, | ||||
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), | |||||
bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), | |||||
weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"), | |||||
bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"), | |||||
) | ) | ||||
assert x.format == "nhwc" | assert x.format == "nhwc" | ||||
return x.numpy() | return x.numpy() | ||||
@@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic): | |||||
if x.format == "nhwc": | if x.format == "nhwc": | ||||
x = F.conv2d( | x = F.conv2d( | ||||
x, | x, | ||||
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), | |||||
bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), | |||||
weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"), | |||||
bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"), | |||||
groups=2, | groups=2, | ||||
) | ) | ||||
assert x.format == "nhwc" | assert x.format == "nhwc" | ||||
@@ -286,10 +286,10 @@ def test_bn(is_symbolic): | |||||
if x.format == "nhwc": | if x.format == "nhwc": | ||||
oups = F.batch_norm( | oups = F.batch_norm( | ||||
x.astype("float32"), | x.astype("float32"), | ||||
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), | |||||
running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), | |||||
weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), | |||||
bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), | |||||
training=True, | training=True, | ||||
inplace=False, | inplace=False, | ||||
) | ) | ||||
@@ -27,7 +27,11 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to( | |||||
pattern = {0, 3, 1, 2}; | pattern = {0, 3, 1, 2}; | ||||
} | } | ||||
} else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) { | } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) { | ||||
pattern = {0, 2, 3, 1}; | |||||
if (tensor.value().shape().cast<ShapeValue>().ndim == 5) { | |||||
pattern = {0, 1, 3, 4, 2}; | |||||
} else { | |||||
pattern = {0, 2, 3, 1}; | |||||
} | |||||
} else { | } else { | ||||
mgb_throw( | mgb_throw( | ||||
MegBrainError, "Unsupport format conversion from %s to %s", | MegBrainError, "Unsupport format conversion from %s to %s", | ||||
@@ -572,8 +576,13 @@ ValueRefList FormatTransformation::apply_transformation( | |||||
} | } | ||||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | } else if (auto* create_tensor = op.as<CreateTensor>()) { | ||||
auto format = create_tensor->format(); | auto format = create_tensor->format(); | ||||
// TODO: add dimshuffle for nhwc format | |||||
return {wrap_output(imperative::apply(op, inputs)[0], format)}; | |||||
if (format == FT::NHWC) { | |||||
auto output = wrap_output(imperative::apply(op, inputs)[0]); | |||||
output = to(output.cast(m_value_type), FT::NHWC, ""); | |||||
return {output}; | |||||
} else { | |||||
return {wrap_output(imperative::apply(op, inputs)[0], format)}; | |||||
} | |||||
} else if (auto* get_attr = op.as<GetAttr>()) { | } else if (auto* get_attr = op.as<GetAttr>()) { | ||||
auto&& input = inputs.item(); | auto&& input = inputs.item(); | ||||
if (!input.is(m_value_type)) { | if (!input.is(m_value_type)) { | ||||