|
|
@@ -10,6 +10,7 @@ |
|
|
|
from typing import Optional, Sequence, Tuple, Union |
|
|
|
|
|
|
|
from ..core._imperative_rt import CompNode |
|
|
|
from ..core._trace_option import use_symbolic_shape |
|
|
|
from ..core.ops import builtin |
|
|
|
from ..core.ops._internal import param_defs as P |
|
|
|
from ..core.ops.builtin import BatchNorm |
|
|
@@ -1015,23 +1016,39 @@ def matmul( |
|
|
|
remove_col = True |
|
|
|
|
|
|
|
batch_shape = None |
|
|
|
shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device) |
|
|
|
shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device) |
|
|
|
shape1 = inp1.shape |
|
|
|
shape2 = inp2.shape |
|
|
|
|
|
|
|
maxdim = dim1 if dim1 > dim2 else dim2 |
|
|
|
if dim1 >= 3 or dim2 >= 3: |
|
|
|
if dim1 == dim2: |
|
|
|
assert ( |
|
|
|
shape1[:-2] == shape2[:-2] |
|
|
|
).min(), "operands could not be broadcasted together." |
|
|
|
if dim1 > dim2: |
|
|
|
shape2 = concat([shape1[:-2], shape2[-2:]]) |
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
if dim1 < dim2: |
|
|
|
shape1 = concat([shape2[:-2], shape1[-2:]]) |
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
batch_shape = shape1[:-2] |
|
|
|
# compress inputs to 3d |
|
|
|
inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]])) |
|
|
|
inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]])) |
|
|
|
if use_symbolic_shape(): |
|
|
|
if dim1 > dim2: |
|
|
|
shape2 = concat([shape1[:-2], shape2[-2:]]) |
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
if dim1 < dim2: |
|
|
|
shape1 = concat([shape2[:-2], shape1[-2:]]) |
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
if maxdim > 3: |
|
|
|
batch_shape = shape1[:-2] |
|
|
|
# compress inputs to 3d |
|
|
|
(inp1,) = apply( |
|
|
|
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) |
|
|
|
) |
|
|
|
(inp2,) = apply( |
|
|
|
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) |
|
|
|
) |
|
|
|
else: |
|
|
|
if dim1 > dim2: |
|
|
|
shape2 = shape1[:-2] + shape2[-2:] |
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
if dim1 < dim2: |
|
|
|
shape1 = shape2[:-2] + shape1[-2:] |
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
if maxdim > 3: |
|
|
|
batch_shape = shape1[:-2] |
|
|
|
# compress inputs to 3d |
|
|
|
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) |
|
|
|
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) |
|
|
|
|
|
|
|
op = builtin.BatchedMatrixMul( |
|
|
|
transposeA=transpose_a, |
|
|
@@ -1048,8 +1065,13 @@ def matmul( |
|
|
|
) |
|
|
|
|
|
|
|
(result,) = apply(op, inp1, inp2) |
|
|
|
if batch_shape is not None: |
|
|
|
result = result.reshape(concat([batch_shape, result.shape[-2:]])) |
|
|
|
if maxdim > 3: |
|
|
|
if use_symbolic_shape(): |
|
|
|
(result,) = apply( |
|
|
|
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) |
|
|
|
) |
|
|
|
else: |
|
|
|
result = result.reshape(batch_shape + result.shape[-2:]) |
|
|
|
if remove_row: |
|
|
|
result = squeeze(result, axis=-2) |
|
|
|
if remove_col: |
|
|
|