|
|
@@ -38,7 +38,7 @@ def test_basic(): |
|
|
|
|
|
|
|
def _compare_nchw_nhwc(data, func, is_symbolic=None): |
|
|
|
x1 = tensor(data) |
|
|
|
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") |
|
|
|
x2 = tensor(data, format="nhwc") |
|
|
|
if is_symbolic is not None: |
|
|
|
func = trace(func, symbolic=is_symbolic) |
|
|
|
out1 = func(x1) |
|
|
@@ -247,8 +247,8 @@ def test_conv2d(is_symbolic): |
|
|
|
if x.format == "nhwc": |
|
|
|
x = F.conv2d( |
|
|
|
x, |
|
|
|
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), |
|
|
|
bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), |
|
|
|
weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"), |
|
|
|
bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"), |
|
|
|
) |
|
|
|
assert x.format == "nhwc" |
|
|
|
return x.numpy() |
|
|
@@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic): |
|
|
|
if x.format == "nhwc": |
|
|
|
x = F.conv2d( |
|
|
|
x, |
|
|
|
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), |
|
|
|
bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), |
|
|
|
weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"), |
|
|
|
bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"), |
|
|
|
groups=2, |
|
|
|
) |
|
|
|
assert x.format == "nhwc" |
|
|
@@ -286,10 +286,10 @@ def test_bn(is_symbolic): |
|
|
|
if x.format == "nhwc": |
|
|
|
oups = F.batch_norm( |
|
|
|
x.astype("float32"), |
|
|
|
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), |
|
|
|
running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), |
|
|
|
weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), |
|
|
|
bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), |
|
|
|
running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), |
|
|
|
running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), |
|
|
|
weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), |
|
|
|
bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), |
|
|
|
training=True, |
|
|
|
inplace=False, |
|
|
|
) |
|
|
|