Browse Source

feat(imperative/amp): add dimshuffle before creating nhwc tensor

GitOrigin-RevId: 4461f9a0d3
release-1.10
Megvii Engine Team 3 years ago
parent
commit
c9e56f4987
3 changed files with 22 additions and 13 deletions
  1. +1
    -1
      imperative/python/megengine/amp/convert_format.py
  2. +9
    -9
      imperative/python/test/unit/core/test_formatted_tensor.py
  3. +12
    -3
      imperative/src/impl/transformations/format.cpp

+ 1
- 1
imperative/python/megengine/amp/convert_format.py View File

@@ -33,7 +33,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
if x.format != "nhwc":
if inplace:
# hostvalue should still be valid, so no d2h cost.
data = x.numpy().transpose(*pattern)
data = x.numpy()
# reset will destroy existed backward grad
x[...] = Tensor(data, format="nhwc")
else:


+ 9
- 9
imperative/python/test/unit/core/test_formatted_tensor.py View File

@@ -38,7 +38,7 @@ def test_basic():

def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data)
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
x2 = tensor(data, format="nhwc")
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
out1 = func(x1)
@@ -247,8 +247,8 @@ def test_conv2d(is_symbolic):
if x.format == "nhwc":
x = F.conv2d(
x,
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"),
weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"),
)
assert x.format == "nhwc"
return x.numpy()
@@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic):
if x.format == "nhwc":
x = F.conv2d(
x,
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"),
weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"),
groups=2,
)
assert x.format == "nhwc"
@@ -286,10 +286,10 @@ def test_bn(is_symbolic):
if x.format == "nhwc":
oups = F.batch_norm(
x.astype("float32"),
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
training=True,
inplace=False,
)


+ 12
- 3
imperative/src/impl/transformations/format.cpp View File

@@ -27,7 +27,11 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
pattern = {0, 3, 1, 2};
}
} else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
pattern = {0, 2, 3, 1};
if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
pattern = {0, 1, 3, 4, 2};
} else {
pattern = {0, 2, 3, 1};
}
} else {
mgb_throw(
MegBrainError, "Unsupport format conversion from %s to %s",
@@ -572,8 +576,13 @@ ValueRefList FormatTransformation::apply_transformation(
}
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto format = create_tensor->format();
// TODO: add dimshuffle for nhwc format
return {wrap_output(imperative::apply(op, inputs)[0], format)};
if (format == FT::NHWC) {
auto output = wrap_output(imperative::apply(op, inputs)[0]);
output = to(output.cast(m_value_type), FT::NHWC, "");
return {output};
} else {
return {wrap_output(imperative::apply(op, inputs)[0], format)};
}
} else if (auto* get_attr = op.as<GetAttr>()) {
auto&& input = inputs.item();
if (!input.is(m_value_type)) {


Loading…
Cancel
Save