|
|
@@ -8,17 +8,21 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import collections |
|
|
|
import math |
|
|
|
from functools import lru_cache |
|
|
|
from typing import Optional, Sequence, Tuple, Union |
|
|
|
|
|
|
|
from ..core._imperative_rt.core2 import apply, dtype_promotion |
|
|
|
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder |
|
|
|
from ..core._trace_option import use_symbolic_shape |
|
|
|
from ..core.ops import builtin |
|
|
|
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor import amp |
|
|
|
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar |
|
|
|
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph |
|
|
|
from ..jit import exclude_from_trace |
|
|
|
from ..tensor import Tensor |
|
|
|
from .debug_param import get_execution_strategy |
|
|
|
from .elemwise import clip |
|
|
|
from .elemwise import clip, minimum |
|
|
|
from .tensor import broadcast_to, concat, expand_dims, squeeze |
|
|
|
|
|
|
|
__all__ = [ |
|
|
@@ -763,6 +767,216 @@ def matinv(inp: Tensor) -> Tensor: |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    """Build the subgraph op that applies ``builtin.MatrixMul`` to two inputs
    whose ranks are ``dim1`` and ``dim2``.

    The builder is memoized with ``lru_cache``, so each distinct combination
    of arguments constructs its subgraph only once.

    Inside the subgraph:

    * an input of rank 1 receives a temporary axis (axis 0 for the first
      input, axis 1 for the second) which is removed again from the result;
    * an input of rank > 2 is reshaped to ``(prod(shape[:-1]), shape[-1])``
      before the multiplication, and the result is reshaped afterwards to
      ``shape[:-1] + result_shape[-1:]``.

    Args:
        device: forwarded to ``subgraph`` and used as ``comp_node`` of the
            ``Concat`` ops that assemble shapes.
        dtype: forwarded to ``subgraph``.
        dim1: rank of the first input.
        dim2: rank of the second input.
        transpose_a: forwarded to ``builtin.MatrixMul`` as ``transposeA``.
        transpose_b: forwarded to ``builtin.MatrixMul`` as ``transposeB``.
        compute_mode: forwarded to ``builtin.MatrixMul``.
        format: forwarded to ``builtin.MatrixMul``.
        strategy: forwarded to ``builtin.MatrixMul``.

    Returns:
        The object produced by decorating the builder with ``subgraph``.
    """

    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        lhs, rhs = inputs

        def shape_before(shape, idx=-1):
            # shape[:idx]
            slicer = builtin.Subtensor(items=[[0, False, True, False, False]])
            return f(slicer, shape, c(idx, "int32"))

        def shape_from(shape, idx=-1):
            # shape[idx:]
            slicer = builtin.Subtensor(items=[[0, True, False, False, False]])
            return f(slicer, shape, c(idx, "int32"))

        def merge_leading_axes(tensor, shape):
            # reshape tensor to (prod(shape[:-1]), shape[-1])
            reshape_op = builtin.Reshape()
            concat_op = builtin.Concat(axis=0, comp_node=device)
            rows = f(builtin.Reduce(mode="product", axis=0), shape_before(shape))
            cols = shape_from(shape)
            return f(reshape_op, tensor, f(concat_op, rows, cols))

        def split_leading_axes(tensor, shape, out_shape):
            # reshape tensor to shape[:-1] + out_shape[-1:]
            reshape_op = builtin.Reshape()
            concat_op = builtin.Concat(axis=0, comp_node=device)
            leading = shape_before(shape)
            trailing = shape_from(out_shape)
            return f(reshape_op, tensor, f(concat_op, leading, trailing))

        # Rank-1 inputs are promoted to rank 2 for the op and squeezed later.
        remove_row = dim1 == 1
        remove_col = dim2 == 1
        rank1 = 2 if remove_row else dim1
        rank2 = 2 if remove_col else dim2

        if remove_row:
            lhs = f(builtin.AddAxis(axis=[0]), lhs)
        if remove_col:
            rhs = f(builtin.AddAxis(axis=[1]), rhs)

        lhs_shape = f(GetVarShape(), lhs)
        rhs_shape = f(GetVarShape(), rhs)

        if rank1 > 2:
            lhs = merge_leading_axes(lhs, lhs_shape)
        if rank2 > 2:
            rhs = merge_leading_axes(rhs, rhs_shape)

        matmul_op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy,
        )
        result = f(matmul_op, lhs, rhs)

        # Captured once; both restore steps below use the shape of the raw
        # MatrixMul output.
        result_shape = f(GetVarShape(), result)
        if rank1 > 2:
            result = split_leading_axes(result, lhs_shape, result_shape)
        if rank2 > 2:
            result = split_leading_axes(result, rhs_shape, result_shape)

        maxdim = max(rank1, rank2)
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)

        # NOTE(review): the second tuple is handed to the ``subgraph``
        # machinery as-is; presumably one flag per output -- confirm.
        return (result,), (True,)

    return extentedMatrixMulOp
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    """Build the subgraph op that applies ``builtin.BatchedMatrixMul`` to two
    inputs whose ranks are ``dim1`` and ``dim2``.

    The builder is memoized with ``lru_cache``, so each distinct combination
    of arguments constructs its subgraph only once.

    Inside the subgraph:

    * an input of rank 1 receives a temporary axis (axis 0 for the first
      input, axis 1 for the second) which is removed again from the result;
    * the lower-rank input is broadcast so that its shape is prefixed with
      the extra leading axes of the higher-rank input;
    * when the larger rank exceeds 3, both inputs are reshaped to
      ``(prod(batch_shape), shape[-2], shape[-1])`` and the result is
      reshaped back to ``batch_shape + result_shape[-2:]``.

    Args:
        device: forwarded to ``subgraph`` and used as ``comp_node`` of the
            ``Concat`` ops that assemble shapes.
        dtype: forwarded to ``subgraph``.
        dim1: rank of the first input.
        dim2: rank of the second input.
        transpose_a: forwarded to ``builtin.BatchedMatrixMul`` as
            ``transposeA``.
        transpose_b: forwarded to ``builtin.BatchedMatrixMul`` as
            ``transposeB``.
        compute_mode: forwarded to ``builtin.BatchedMatrixMul``.
        format: forwarded to ``builtin.BatchedMatrixMul``.
        strategy: forwarded to ``builtin.BatchedMatrixMul``.

    Returns:
        The object produced by decorating the builder with ``subgraph``.
    """

    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        lhs, rhs = inputs

        def shape_before(shape, idx=-2):
            # shape[:idx]
            slicer = builtin.Subtensor(items=[[0, False, True, False, False]])
            return f(slicer, shape, c(idx, "int32"))

        def shape_from(shape, idx=-2):
            # shape[idx:]
            slicer = builtin.Subtensor(items=[[0, True, False, False, False]])
            return f(slicer, shape, c(idx, "int32"))

        def flatten_batch(tensor, shape, batch):
            # reshape tensor to (prod(batch), shape[-2], shape[-1])
            reshape_op = builtin.Reshape()
            concat_op = builtin.Concat(axis=0, comp_node=device)
            batch_size = f(builtin.Reduce(mode="product", axis=0), batch)
            matrix_dims = shape_from(shape)
            return f(reshape_op, tensor, f(concat_op, batch_size, matrix_dims))

        # Rank-1 inputs are promoted to rank 2 for the op and squeezed later.
        remove_row = dim1 == 1
        remove_col = dim2 == 1
        rank1 = 2 if remove_row else dim1
        rank2 = 2 if remove_col else dim2

        if remove_row:
            lhs = f(builtin.AddAxis(axis=[0]), lhs)
        if remove_col:
            rhs = f(builtin.AddAxis(axis=[1]), rhs)

        lhs_shape = f(GetVarShape(), lhs)
        rhs_shape = f(GetVarShape(), rhs)
        maxdim = max(rank1, rank2)

        if rank1 > rank2:
            # broadcast the second input: lhs_shape[:-rank2] + rhs_shape
            rhs_shape = f(
                builtin.Concat(axis=0, comp_node=device),
                shape_before(lhs_shape, idx=-rank2),
                rhs_shape,
            )
            rhs = f(builtin.Broadcast(), rhs, rhs_shape)
            batch_shape = shape_before(lhs_shape)
        elif rank2 > rank1:
            # broadcast the first input: rhs_shape[:-rank1] + lhs_shape
            lhs_shape = f(
                builtin.Concat(axis=0, comp_node=device),
                shape_before(rhs_shape, idx=-rank1),
                lhs_shape,
            )
            lhs = f(builtin.Broadcast(), lhs, lhs_shape)
            batch_shape = shape_before(rhs_shape)
        else:
            batch_shape = shape_before(lhs_shape)

        # compress inputs to 3d
        if maxdim > 3:
            lhs = flatten_batch(lhs, lhs_shape, batch_shape)
            rhs = flatten_batch(rhs, rhs_shape, batch_shape)

        bmm_op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy,
        )
        result = f(bmm_op, lhs, rhs)

        if maxdim > 3:
            # expand the flattened batch axis back to batch_shape
            reshape_op = builtin.Reshape()
            concat_op = builtin.Concat(axis=0, comp_node=device)
            matrix_dims = shape_from(f(GetVarShape(), result))
            result = f(reshape_op, result, f(concat_op, batch_shape, matrix_dims))

        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)

        # NOTE(review): the second tuple is handed to the ``subgraph``
        # machinery as-is; presumably one flag per output -- confirm.
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
|
|
|
|
|
|
|
|
|
|
|
def matmul( |
|
|
|
inp1: Tensor, |
|
|
|
inp2: Tensor, |
|
|
@@ -822,85 +1036,39 @@ def matmul( |
|
|
|
if inp2.dtype != dtype: |
|
|
|
inp2 = inp2.astype(dtype) |
|
|
|
|
|
|
|
remove_row, remove_col = False, False |
|
|
|
dim1, dim2 = inp1.ndim, inp2.ndim |
|
|
|
# handle dim=1 cases, dot and matrix-vector multiplication |
|
|
|
if dim1 == 1 and dim2 == 1: |
|
|
|
return dot(inp1, inp2) |
|
|
|
# the underlying matmul op requires input dims to be at least 2 |
|
|
|
if dim1 == 1: |
|
|
|
inp1 = expand_dims(inp1, 0) |
|
|
|
dim1 = 2 |
|
|
|
remove_row = True |
|
|
|
if dim2 == 1: |
|
|
|
inp2 = expand_dims(inp2, 1) |
|
|
|
dim2 = 2 |
|
|
|
remove_col = True |
|
|
|
|
|
|
|
batch_shape = None |
|
|
|
shape1 = inp1.shape |
|
|
|
shape2 = inp2.shape |
|
|
|
|
|
|
|
assert dim1 > 0 and dim2 > 0 |
|
|
|
maxdim = dim1 if dim1 > dim2 else dim2 |
|
|
|
if dim1 >= 3 or dim2 >= 3: |
|
|
|
if use_symbolic_shape(): |
|
|
|
if dim1 > dim2: |
|
|
|
shape2 = concat([shape1[:-2], shape2[-2:]]) |
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
if dim1 < dim2: |
|
|
|
shape1 = concat([shape2[:-2], shape1[-2:]]) |
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
if maxdim > 3: |
|
|
|
batch_shape = shape1[:-2] |
|
|
|
# compress inputs to 3d |
|
|
|
(inp1,) = apply( |
|
|
|
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) |
|
|
|
) |
|
|
|
(inp2,) = apply( |
|
|
|
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) |
|
|
|
) |
|
|
|
else: |
|
|
|
if dim1 > dim2: |
|
|
|
shape2 = shape1[:-2] + shape2[-2:] |
|
|
|
inp2 = broadcast_to(inp2, shape2) |
|
|
|
if dim1 < dim2: |
|
|
|
shape1 = shape2[:-2] + shape1[-2:] |
|
|
|
inp1 = broadcast_to(inp1, shape1) |
|
|
|
if maxdim > 3: |
|
|
|
batch_shape = shape1[:-2] |
|
|
|
# compress inputs to 3d |
|
|
|
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) |
|
|
|
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) |
|
|
|
|
|
|
|
op = builtin.BatchedMatrixMul( |
|
|
|
transposeA=transpose_a, |
|
|
|
transposeB=transpose_b, |
|
|
|
compute_mode=compute_mode, |
|
|
|
format=format, |
|
|
|
if dim1 == 1 and dim2 == 1: # dispatch to Dot |
|
|
|
return dot(inp1, inp2) |
|
|
|
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul |
|
|
|
extentedMatrixMulOp = _get_extentedMatrixMulOp( |
|
|
|
inp1.device, |
|
|
|
inp1.dtype, |
|
|
|
dim1, |
|
|
|
dim2, |
|
|
|
transpose_a, |
|
|
|
transpose_b, |
|
|
|
compute_mode, |
|
|
|
format, |
|
|
|
strategy=get_execution_strategy(), |
|
|
|
) |
|
|
|
else: |
|
|
|
op = builtin.MatrixMul( |
|
|
|
transposeA=transpose_a, |
|
|
|
transposeB=transpose_b, |
|
|
|
compute_mode=compute_mode, |
|
|
|
format=format, |
|
|
|
(result,) = apply(extentedMatrixMulOp, inp1, inp2) |
|
|
|
return result |
|
|
|
else: # dispath to BatchedMatrixMul |
|
|
|
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( |
|
|
|
inp1.device, |
|
|
|
inp1.dtype, |
|
|
|
dim1, |
|
|
|
dim2, |
|
|
|
transpose_a, |
|
|
|
transpose_b, |
|
|
|
compute_mode, |
|
|
|
format, |
|
|
|
strategy=get_execution_strategy(), |
|
|
|
) |
|
|
|
|
|
|
|
(result,) = apply(op, inp1, inp2) |
|
|
|
if maxdim > 3: |
|
|
|
if use_symbolic_shape(): |
|
|
|
(result,) = apply( |
|
|
|
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) |
|
|
|
) |
|
|
|
else: |
|
|
|
result = result.reshape(batch_shape + result.shape[-2:]) |
|
|
|
if remove_row: |
|
|
|
result = squeeze(result, axis=-2) |
|
|
|
if remove_col: |
|
|
|
result = squeeze(result, axis=-1) |
|
|
|
return result |
|
|
|
(result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def dot(inp1: Tensor, inp2: Tensor) -> Tensor: |
|
|
|