|
@@ -23,7 +23,7 @@ from ..tensor import Tensor |
|
|
from .debug_param import get_conv_execution_strategy |
|
|
from .debug_param import get_conv_execution_strategy |
|
|
from .distributed import all_reduce_sum |
|
|
from .distributed import all_reduce_sum |
|
|
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu |
|
|
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu |
|
|
from .math import argsort, max, sum |
|
|
|
|
|
|
|
|
from .math import argsort, max, prod, sum |
|
|
from .tensor import ( |
|
|
from .tensor import ( |
|
|
broadcast_to, |
|
|
broadcast_to, |
|
|
concat, |
|
|
concat, |
|
@@ -972,38 +972,42 @@ def matmul( |
|
|
[28. 40.]] |
|
|
[28. 40.]] |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
remove_row, remove_col = False, False |
|
|
inp1, inp2 = utils.convert_inputs(inp1, inp2) |
|
|
inp1, inp2 = utils.convert_inputs(inp1, inp2) |
|
|
|
|
|
|
|
|
dim1, dim2 = inp1.ndim, inp2.ndim |
|
|
dim1, dim2 = inp1.ndim, inp2.ndim |
|
|
|
|
|
# handle dim=1 cases, dot and matrix-vector multiplication |
|
|
if dim1 == 1 and dim2 == 1: |
|
|
if dim1 == 1 and dim2 == 1: |
|
|
return dot(inp1, inp2) |
|
|
return dot(inp1, inp2) |
|
|
|
|
|
# the underlying matmul op requires input dims to be at least 2 |
|
|
|
|
|
if dim1 == 1: |
|
|
|
|
|
inp1 = expand_dims(inp1, 0) |
|
|
|
|
|
dim1 = 2 |
|
|
|
|
|
remove_row = True |
|
|
|
|
|
if dim2 == 1: |
|
|
|
|
|
inp2 = expand_dims(inp2, 1) |
|
|
|
|
|
dim2 = 2 |
|
|
|
|
|
remove_col = True |
|
|
|
|
|
|
|
|
|
|
|
batch_shape = None |
|
|
|
|
|
shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device) |
|
|
|
|
|
shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device) |
|
|
|
|
|
if dim1 >= 3 or dim2 >= 3: |
|
|
|
|
|
if dim1 == dim2: |
|
|
|
|
|
assert ( |
|
|
|
|
|
shape1[:-2] == shape2[:-2] |
|
|
|
|
|
).min(), "operands could not be broadcasted together." |
|
|
|
|
|
if dim1 > dim2: |
|
|
|
|
|
shape2 = concat([shape1[:-2], shape2[-2:]]) |
|
|
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
|
|
if dim1 < dim2: |
|
|
|
|
|
shape1 = concat([shape2[:-2], shape1[-2:]]) |
|
|
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
|
|
batch_shape = shape1[:-2] |
|
|
|
|
|
# compress inputs to 3d |
|
|
|
|
|
inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]])) |
|
|
|
|
|
inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]])) |
|
|
|
|
|
|
|
|
shp = None |
|
|
|
|
|
if dim1 > 3 or dim2 > 3: |
|
|
|
|
|
shape1, shape2 = list(inp1.shape), list(inp2.shape) |
|
|
|
|
|
if dim1 != dim2: |
|
|
|
|
|
if dim1 < dim2: |
|
|
|
|
|
shape1 = shape2[: dim2 - dim1] + shape1 |
|
|
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
|
|
else: |
|
|
|
|
|
shape2 = shape1[: dim1 - dim2] + shape2 |
|
|
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
|
|
reshaped_batch_size = 1 |
|
|
|
|
|
for i in shape1[:-2]: |
|
|
|
|
|
reshaped_batch_size *= i |
|
|
|
|
|
inp1 = inp1.reshape(*([reshaped_batch_size] + shape1[-2:])) |
|
|
|
|
|
inp2 = inp2.reshape(*([reshaped_batch_size] + shape2[-2:])) |
|
|
|
|
|
op = builtin.BatchedMatrixMul( |
|
|
|
|
|
transposeA=transpose_a, |
|
|
|
|
|
transposeB=transpose_b, |
|
|
|
|
|
compute_mode=compute_mode, |
|
|
|
|
|
format=format, |
|
|
|
|
|
) |
|
|
|
|
|
shp = shape1[:-1] + shape2[-1:] |
|
|
|
|
|
elif dim1 == 3 or dim2 == 3: |
|
|
|
|
|
if dim2 < 3: |
|
|
|
|
|
inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape) |
|
|
|
|
|
elif dim1 < 3: |
|
|
|
|
|
inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape) |
|
|
|
|
|
op = builtin.BatchedMatrixMul( |
|
|
op = builtin.BatchedMatrixMul( |
|
|
transposeA=transpose_a, |
|
|
transposeA=transpose_a, |
|
|
transposeB=transpose_b, |
|
|
transposeB=transpose_b, |
|
@@ -1011,12 +1015,6 @@ def matmul( |
|
|
format=format, |
|
|
format=format, |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
if dim1 == 1: |
|
|
|
|
|
shp = (inp2.shape[1],) |
|
|
|
|
|
inp1 = expand_dims(inp1, 0) |
|
|
|
|
|
if dim2 == 1: |
|
|
|
|
|
shp = (inp1.shape[0],) |
|
|
|
|
|
inp2 = expand_dims(inp2, 1) |
|
|
|
|
|
op = builtin.MatrixMul( |
|
|
op = builtin.MatrixMul( |
|
|
transposeA=transpose_a, |
|
|
transposeA=transpose_a, |
|
|
transposeB=transpose_b, |
|
|
transposeB=transpose_b, |
|
@@ -1025,8 +1023,12 @@ def matmul( |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
(result,) = apply(op, inp1, inp2) |
|
|
(result,) = apply(op, inp1, inp2) |
|
|
if shp is not None: |
|
|
|
|
|
result = result.reshape(shp) |
|
|
|
|
|
|
|
|
if batch_shape is not None: |
|
|
|
|
|
result = result.reshape(concat([batch_shape, result.shape[-2:]])) |
|
|
|
|
|
if remove_row: |
|
|
|
|
|
result = squeeze(result, axis=-2) |
|
|
|
|
|
if remove_col: |
|
|
|
|
|
result = squeeze(result, axis=-1) |
|
|
return result |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|