
refactor(mge/functional): matmul supports symbolic shape, batched mv multiply

GitOrigin-RevId: c4d8cf3306
release-1.1
Megvii Engine Team 4 years ago
parent commit 1fa143ce87
2 changed files with 68 additions and 49 deletions
  1. +38 -36  imperative/python/megengine/functional/nn.py
  2. +30 -13  imperative/python/test/unit/functional/test_functional.py
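In short: F.matmul now handles a 1-d operand against a batched matrix (the vector is promoted to a matrix, multiplied, then squeezed back out) and computes output shapes symbolically rather than from concrete Python ints. A minimal sketch of the newly supported batched matrix-vector call, reusing the (2,) x (batch, 2, 3) shapes added to the test below (assumes a MegEngine 1.1 environment; not part of the commit itself):

import numpy as np
import megengine.functional as F
from megengine import Tensor

vec = Tensor(np.random.random((2,)).astype("float32"))          # 1-d operand
mats = Tensor(np.random.random((10, 2, 3)).astype("float32"))   # batched matrices

# vec is internally expanded to (1, 2), broadcast over the batch,
# and the singleton row axis is squeezed away again
out = F.matmul(vec, mats)
print(out.shape)  # (10, 3), matching np.matmul on the same NumPy arrays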

+38 -36  imperative/python/megengine/functional/nn.py

@@ -23,7 +23,7 @@ from ..tensor import Tensor
 from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
-from .math import argsort, max, sum
+from .math import argsort, max, prod, sum
 from .tensor import (
     broadcast_to,
     concat,
@@ -972,38 +972,42 @@ def matmul(
          [28. 40.]]

     """
+    remove_row, remove_col = False, False
     inp1, inp2 = utils.convert_inputs(inp1, inp2)

     dim1, dim2 = inp1.ndim, inp2.ndim
+    # handle dim=1 cases, dot and matrix-vector multiplication
     if dim1 == 1 and dim2 == 1:
         return dot(inp1, inp2)
-    shp = None
-    if dim1 > 3 or dim2 > 3:
-        shape1, shape2 = list(inp1.shape), list(inp2.shape)
-        if dim1 != dim2:
-            if dim1 < dim2:
-                shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = broadcast_to(inp1, shape1)
-            else:
-                shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = broadcast_to(inp2, shape2)
-        reshaped_batch_size = 1
-        for i in shape1[:-2]:
-            reshaped_batch_size *= i
-        inp1 = inp1.reshape(*([reshaped_batch_size] + shape1[-2:]))
-        inp2 = inp2.reshape(*([reshaped_batch_size] + shape2[-2:]))
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-        )
-        shp = shape1[:-1] + shape2[-1:]
-    elif dim1 == 3 or dim2 == 3:
-        if dim2 < 3:
-            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
-        elif dim1 < 3:
-            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
+    # the underlying matmul op requires input dims to be at least 2
+    if dim1 == 1:
+        inp1 = expand_dims(inp1, 0)
+        dim1 = 2
+        remove_row = True
+    if dim2 == 1:
+        inp2 = expand_dims(inp2, 1)
+        dim2 = 2
+        remove_col = True
+
+    batch_shape = None
+    shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
+    shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
+    if dim1 >= 3 or dim2 >= 3:
+        if dim1 == dim2:
+            assert (
+                shape1[:-2] == shape2[:-2]
+            ).min(), "operands could not be broadcasted together."
+        if dim1 > dim2:
+            shape2 = concat([shape1[:-2], shape2[-2:]])
+            inp2 = broadcast_to(inp2, shape2)
+        if dim1 < dim2:
+            shape1 = concat([shape2[:-2], shape1[-2:]])
+            inp1 = broadcast_to(inp1, shape1)
+        batch_shape = shape1[:-2]
+        # compress inputs to 3d
+        inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
+        inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1011,12 +1015,6 @@ def matmul(
             format=format,
         )
     else:
-        if dim1 == 1:
-            shp = (inp2.shape[1],)
-            inp1 = expand_dims(inp1, 0)
-        if dim2 == 1:
-            shp = (inp1.shape[0],)
-            inp2 = expand_dims(inp2, 1)
         op = builtin.MatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1025,8 +1023,12 @@ def matmul(
         )

     (result,) = apply(op, inp1, inp2)
-    if shp is not None:
-        result = result.reshape(shp)
+    if batch_shape is not None:
+        result = result.reshape(concat([batch_shape, result.shape[-2:]]))
+    if remove_row:
+        result = squeeze(result, axis=-2)
+    if remove_col:
+        result = squeeze(result, axis=-1)
     return result
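The n-d branch above folds every leading batch dimension into a single axis before dispatching to BatchedMatrixMul, then unfolds the result afterwards. The same bookkeeping written in plain NumPy, as a sketch of the idea only (equal batch shapes, i.e. the dim1 == dim2 path; shapes are illustrative, not from the commit):

import numpy as np

def matmul_via_3d(a, b):
    # fold all leading batch dims into one axis, mirroring the
    # prod(shape[:-2]) reshape in the new implementation
    batch_shape = a.shape[:-2]
    a3 = a.reshape((-1,) + a.shape[-2:])
    b3 = b.reshape((-1,) + b.shape[-2:])
    out = np.matmul(a3, b3)  # one 3-d batched matmul
    # restore the original batch dims on the result
    return out.reshape(batch_shape + out.shape[-2:])

a = np.random.random((10, 10, 4, 2)).astype("float32")
b = np.random.random((10, 10, 2, 4)).astype("float32")
assert np.allclose(matmul_via_3d(a, b), np.matmul(a, b), atol=1e-6)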
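Why astensor1d, concat, and prod instead of list arithmetic on inp.shape: keeping the shape as a tensor means the reshape still works when shapes are placeholders at trace time, which is the "symbolic shape" half of the commit title. A sketch of that kind of use, assuming the megengine.jit.trace decorator of MegEngine 1.x:

import numpy as np
import megengine.functional as F
from megengine import Tensor
from megengine.jit import trace

@trace(symbolic=True)  # shapes are symbolic while tracing
def fn(a, b):
    return F.matmul(a, b)

a = Tensor(np.random.random((10, 4, 2)).astype("float32"))
b = Tensor(np.random.random((10, 2, 4)).astype("float32"))
print(fn(a, b).shape)  # (10, 4, 4)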
+30 -13  imperative/python/test/unit/functional/test_functional.py

@@ -77,26 +77,43 @@ def test_matmul():
     opr_test(cases, F.matmul, ref_fn=np.matmul)

     batch_size = 10
-    shape1 = (batch_size, 2, 3)
-    shape2 = (batch_size, 3, 4)
-    shape3 = (batch_size, 10, 4, 5)
+    shape1 = (2,)
+    shape2 = (batch_size, 2, 3)
+    shape3 = (batch_size, 3, 4)
+    shape4 = (batch_size, 10, 4, 2)
+    shape5 = (batch_size, 10, 2, 4)
     data1 = np.random.random(shape1).astype("float32")
     data2 = np.random.random(shape2).astype("float32")
     data3 = np.random.random(shape3).astype("float32")
+    data4 = np.random.random(shape4).astype("float32")
+    data5 = np.random.random(shape5).astype("float32")

-    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
-    for i in range(0, batch_size):
-
-        def compare_fn(x, y):
-            x.numpy()[i, ...] == y
-
+    cases = [
+        {"input": [data1, data2]},
+        {"input": [data2, data3]},
+        {"input": [data3, data4]},
+        {"input": [data4, data5]},
+    ]
+    for _ in range(0, batch_size):
         opr_test(
-            cases,
-            F.matmul,
-            compare_fn=compare_fn,
-            ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
+            cases, F.matmul, ref_fn=np.matmul,
         )

+    opr_test(
+        [{"input": [data1, data4]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
+        transpose_b=True,
+    )
+
+    opr_test(
+        [{"input": [data3, data2]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
+        transpose_a=True,
+        transpose_b=True,
+    )
+

 def test_interpolate():
     def linear_interpolate():
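The two new opr_test calls pin down the transpose_a/transpose_b semantics: each flag transposes only the last two axes of its operand, which is exactly what the reference lambdas spell out with explicit axis permutations. The equivalent check in plain NumPy (hypothetical shapes, for illustration only):

import numpy as np

x = np.random.random((10, 10, 4, 2)).astype("float32")
y = np.random.random((10, 10, 4, 2)).astype("float32")

# transpose_b=True multiplies by y with its last two axes swapped;
# np.swapaxes(y, -1, -2) is the same permutation the test writes
# as y.transpose(0, 1, 3, 2)
ref = np.matmul(x, np.swapaxes(y, -1, -2))
assert ref.shape == (10, 10, 4, 4)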

