diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index a323aefe..9bf4bd1f 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -345,9 +345,10 @@ class ArrayMethodMixin(abc.ABC): @property def ndim(self): shape = self.shape - # XXX: assume ndim is not changed during trace if isinstance(shape, self.__class__): - shape = shape.numpy() + # XXX: assume ndim is not changed during trace + ndim = shape.__wrapped__.shape[0] + return ndim return len(shape) @property diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 875118a0..421defd6 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -10,6 +10,7 @@ from typing import Optional, Sequence, Tuple, Union from ..core._imperative_rt import CompNode +from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin from ..core.ops._internal import param_defs as P from ..core.ops.builtin import BatchNorm @@ -1015,23 +1016,39 @@ def matmul( remove_col = True batch_shape = None - shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device) - shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device) + shape1 = inp1.shape + shape2 = inp2.shape + + maxdim = dim1 if dim1 > dim2 else dim2 if dim1 >= 3 or dim2 >= 3: - if dim1 == dim2: - assert ( - shape1[:-2] == shape2[:-2] - ).min(), "operands could not be broadcasted together." - if dim1 > dim2: - shape2 = concat([shape1[:-2], shape2[-2:]]) - inp2 = broadcast_to(inp2, shape2) - if dim1 < dim2: - shape1 = concat([shape2[:-2], shape1[-2:]]) - inp1 = broadcast_to(inp1, shape1) - batch_shape = shape1[:-2] - # compress inputs to 3d - inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]])) - inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]])) + if use_symbolic_shape(): + if dim1 > dim2: + shape2 = concat([shape1[:-2], shape2[-2:]]) + inp2 = broadcast_to(inp2, shape2) + if dim1 < dim2: + shape1 = concat([shape2[:-2], shape1[-2:]]) + inp1 = broadcast_to(inp1, shape1) + if maxdim > 3: + batch_shape = shape1[:-2] + # compress inputs to 3d + (inp1,) = apply( + builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) + ) + (inp2,) = apply( + builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) + ) + else: + if dim1 > dim2: + shape2 = shape1[:-2] + shape2[-2:] + inp2 = broadcast_to(inp2, shape2) + if dim1 < dim2: + shape1 = shape2[:-2] + shape1[-2:] + inp1 = broadcast_to(inp1, shape1) + if maxdim > 3: + batch_shape = shape1[:-2] + # compress inputs to 3d + inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) + inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) op = builtin.BatchedMatrixMul( transposeA=transpose_a, @@ -1048,8 +1065,13 @@ def matmul( ) (result,) = apply(op, inp1, inp2) - if batch_shape is not None: - result = result.reshape(concat([batch_shape, result.shape[-2:]])) + if maxdim > 3: + if use_symbolic_shape(): + (result,) = apply( + builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) + ) + else: + result = result.reshape(batch_shape + result.shape[-2:]) if remove_row: result = squeeze(result, axis=-2) if remove_col: