
perf(mge/functional): reduce matmul python overhead

GitOrigin-RevId: 738d0da10e
release-1.1 · Megvii Engine Team · 4 years ago
commit 1e0fb127d1
2 changed files with 43 additions and 20 deletions:
  1. imperative/python/megengine/core/tensor/tensor_wrapper.py (+3, -2)
  2. imperative/python/megengine/functional/nn.py (+40, -18)

imperative/python/megengine/core/tensor/tensor_wrapper.py (+3, -2)

@@ -345,9 +345,10 @@ class ArrayMethodMixin(abc.ABC):
     @property
     def ndim(self):
         shape = self.shape
-        # XXX: assume ndim is not changed during trace
         if isinstance(shape, self.__class__):
-            shape = shape.numpy()
+            # XXX: assume ndim is not changed during trace
+            ndim = shape.__wrapped__.shape[0]
+            return ndim
         return len(shape)
 
     @property
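
The point of this hunk is that ndim no longer calls .numpy() on a symbolic shape, which forces a device round trip during tracing; the length of the 1-d shape tensor is statically known and can be read from the wrapped handle. A minimal sketch of the idea, using made-up stand-in classes rather than MegEngine internals:

    # Toy illustration of the ndim fast path; RawHandle/SymbolicShape are
    # invented stand-ins, not MegEngine classes.
    class RawHandle:
        def __init__(self, static_shape):
            self.shape = static_shape        # static metadata, no device sync

    class SymbolicShape:
        def __init__(self, ndim):
            self.__wrapped__ = RawHandle((ndim,))
        def numpy(self):
            raise RuntimeError("would block on device execution during trace")

    def ndim_of(shape):
        if isinstance(shape, SymbolicShape):
            # new path: read the static length of the 1-d shape tensor
            return shape.__wrapped__.shape[0]
        return len(shape)                    # eager path: shape is a plain tuple

    assert ndim_of(SymbolicShape(4)) == 4    # no .numpy() round trip
    assert ndim_of((2, 3, 4)) == 3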


imperative/python/megengine/functional/nn.py (+40, -18)

@@ -10,6 +10,7 @@
 from typing import Optional, Sequence, Tuple, Union
 
 from ..core._imperative_rt import CompNode
+from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.builtin import BatchNorm
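
use_symbolic_shape() reports whether Tensor.shape currently yields a symbolic shape tensor (e.g. under symbolic tracing) or a plain Python tuple; the matmul change below branches on it so the common eager case can do its shape arithmetic without building tensors. A hedged sketch of toggling it; set_symbolic_shape is assumed to be the companion setter in the same module, so verify it against your MegEngine version:

    # Assumption: set_symbolic_shape lives next to use_symbolic_shape in
    # megengine.core._trace_option and returns the previous setting.
    from megengine.core._trace_option import set_symbolic_shape, use_symbolic_shape

    prev = set_symbolic_shape(False)   # eager: .shape is a tuple of ints
    print(use_symbolic_shape())        # False -> matmul takes the tuple branch
    set_symbolic_shape(prev)           # restore the previous setting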
@@ -1015,23 +1016,39 @@ def matmul(
         remove_col = True
 
     batch_shape = None
-    shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
-    shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
+    shape1 = inp1.shape
+    shape2 = inp2.shape
+
+    maxdim = dim1 if dim1 > dim2 else dim2
     if dim1 >= 3 or dim2 >= 3:
-        if dim1 == dim2:
-            assert (
-                shape1[:-2] == shape2[:-2]
-            ).min(), "operands could not be broadcasted together."
-        if dim1 > dim2:
-            shape2 = concat([shape1[:-2], shape2[-2:]])
-            inp2 = broadcast_to(inp2, shape2)
-        if dim1 < dim2:
-            shape1 = concat([shape2[:-2], shape1[-2:]])
-            inp1 = broadcast_to(inp1, shape1)
-        batch_shape = shape1[:-2]
-        # compress inputs to 3d
-        inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
-        inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
+        if use_symbolic_shape():
+            if dim1 > dim2:
+                shape2 = concat([shape1[:-2], shape2[-2:]])
+                inp2 = broadcast_to(inp2, shape2)
+            if dim1 < dim2:
+                shape1 = concat([shape2[:-2], shape1[-2:]])
+                inp1 = broadcast_to(inp1, shape1)
+            if maxdim > 3:
+                batch_shape = shape1[:-2]
+                # compress inputs to 3d
+                (inp1,) = apply(
+                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
+                )
+                (inp2,) = apply(
+                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
+                )
+        else:
+            if dim1 > dim2:
+                shape2 = shape1[:-2] + shape2[-2:]
+                inp2 = broadcast_to(inp2, shape2)
+            if dim1 < dim2:
+                shape1 = shape2[:-2] + shape1[-2:]
+                inp1 = broadcast_to(inp1, shape1)
+            if maxdim > 3:
+                batch_shape = shape1[:-2]
+                # compress inputs to 3d
+                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
+                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
 
     op = builtin.BatchedMatrixMul(
         transposeA=transpose_a,
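
The eager (else) branch now does all of this with plain Python tuples: broadcast the leading batch dimensions, flatten both operands to (batch, m, k) so a single BatchedMatrixMul covers any rank, and later unflatten the result (the hunk that follows). A NumPy sketch of that shape arithmetic, meant for sanity-checking the logic rather than reproducing MegEngine's kernels:

    import numpy as np

    def matmul_nd_sketch(a, b):
        # Broadcast leading (batch) dims, as the eager branch does with tuples.
        if a.ndim > b.ndim:
            b = np.broadcast_to(b, a.shape[:-2] + b.shape[-2:])
        if a.ndim < b.ndim:
            a = np.broadcast_to(a, b.shape[:-2] + a.shape[-2:])
        batch_shape = a.shape[:-2]
        # Compress both operands to 3-d so one batched matmul handles any rank.
        a3 = a.reshape((-1, a.shape[-2], a.shape[-1]))
        b3 = b.reshape((-1, b.shape[-2], b.shape[-1]))
        out = np.einsum("bij,bjk->bik", a3, b3)   # stand-in for BatchedMatrixMul
        # Restore the batch shape on the result (mirrors the next hunk).
        return out.reshape(batch_shape + out.shape[-2:])

    a = np.random.rand(2, 3, 4, 5)
    b = np.random.rand(5, 6)
    res = matmul_nd_sketch(a, b)
    assert res.shape == (2, 3, 4, 6)
    np.testing.assert_allclose(res, a @ b, rtol=1e-8)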
@@ -1048,8 +1065,13 @@ def matmul(
     )
 
     (result,) = apply(op, inp1, inp2)
-    if batch_shape is not None:
-        result = result.reshape(concat([batch_shape, result.shape[-2:]]))
+    if maxdim > 3:
+        if use_symbolic_shape():
+            (result,) = apply(
+                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
+            )
+        else:
+            result = result.reshape(batch_shape + result.shape[-2:])
     if remove_row:
         result = squeeze(result, axis=-2)
     if remove_col:
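
As a quick end-to-end check of the >3-d path touched by these hunks, a hedged usage sketch, assuming a MegEngine 1.x install:

    # Hedged usage sketch; relies only on the public megengine / F.matmul API.
    import numpy as np
    import megengine
    import megengine.functional as F

    a = megengine.tensor(np.random.rand(2, 3, 4, 5).astype("float32"))
    b = megengine.tensor(np.random.rand(2, 3, 5, 6).astype("float32"))

    out = F.matmul(a, b)               # batched matmul, hits the maxdim > 3 path
    print(out.shape)                   # (2, 3, 4, 6)
    np.testing.assert_allclose(
        out.numpy(), np.matmul(a.numpy(), b.numpy()), rtol=1e-4
    )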

