GitOrigin-RevId: bf6136cc1a
tags/v1.9.0

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import collections
from functools import lru_cache
from typing import Union

import numpy as np

@@ -24,8 +25,8 @@ from .utils import (
    astype,
    cast_tensors,
    convert_inputs,
    isscalar,
    make_shape_tuple,
    subgraph,
)

_ElwMod = builtin.Elemwise.Mode
@@ -73,23 +74,292 @@ def _elwise(*args, mode):
    return _elwise_apply(args, mode)

def _matmul(inp1, inp2):

@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(builtin.GetVarShape(), result)
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp
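
# For intuition: the subgraph above reduces every case to one 2-D MatrixMul by
# folding the leading axes of a higher-rank operand into the row/column axis
# and unfolding the result afterwards. A minimal NumPy sketch of that idea for
# the higher-rank-left case (illustrative only; not part of this diff):

import numpy as np

def _extended_matmul_2d_sketch(a, b):
    lead = a.shape[:-1]                     # leading axes to fold away
    out = a.reshape(-1, a.shape[-1]) @ b    # a single plain 2-D matmul
    return out.reshape(*lead, b.shape[-1])  # restore the leading axes

assert _extended_matmul_2d_sketch(
    np.random.rand(3, 10, 4), np.random.rand(4, 10)
).shape == (3, 10, 10)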
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(builtin.GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
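
# The batched variant differs in two ways: the lower-rank operand first has the
# other operand's extra leading (batch) axes broadcast onto it, and then all
# batch axes are fused into one so a rank-3 BatchedMatrixMul suffices. Roughly,
# in NumPy (a sketch under those assumptions; not part of this diff):

def _batched_matmul_sketch(a, b):
    if a.ndim > b.ndim:  # broadcast b's batch axes up to a's
        b = np.broadcast_to(b, a.shape[: a.ndim - b.ndim] + b.shape)
    elif b.ndim > a.ndim:
        a = np.broadcast_to(a, b.shape[: b.ndim - a.ndim] + a.shape)
    batch = a.shape[:-2]
    a3 = a.reshape(-1, *a.shape[-2:])  # compress inputs to 3d
    b3 = b.reshape(-1, *b.shape[-2:])
    out = np.einsum("bij,bjk->bik", a3, b3)
    return out.reshape(*batch, *out.shape[-2:])  # restore the batch axes

assert _batched_matmul_sketch(
    np.random.rand(2, 3, 5, 4), np.random.rand(3, 4, 7)
).shape == (2, 3, 5, 7)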
class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value
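
# _Hashable exists because the two builders above are memoized with
# functools.lru_cache, which requires every argument to be hashable; wrapping
# the strategy gives it stable hash/equality semantics (via its string form),
# and the subgraph unwraps it with .value. A toy illustration of the pattern
# (hypothetical _make_op_sketch, not from this diff):

from functools import lru_cache

@lru_cache(maxsize=None)
def _make_op_sketch(strategy):
    return ("op-for", str(strategy.value))  # built once per distinct strategy

assert _make_op_sketch(_Hashable("HEURISTIC")) is _make_op_sketch(
    _Hashable("HEURISTIC")
)  # equal wrappers hash alike and share one cache entry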
def _matmul(
    inp1,
    inp2,
    transpose_a=False,
    transpose_b=False,
    compute_mode="default",
    format="default",
):
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        compute_mode = "default"
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    op = builtin.MatrixMul(
        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
    )
    (result,) = apply(op, inp1, inp2)
    return result
    Strategy = builtin.ops.MatrixMul.Strategy
    strategy = Strategy(0)
    if _config._benchmark_kernel:
        strategy |= Strategy.PROFILE
    else:
        strategy |= Strategy.HEURISTIC
    if _config._deterministic_kernel:
        strategy |= Strategy.REPRODUCIBLE
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        (result,) = apply(builtin.Dot(), inp1, inp2)
        return result
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result
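
# Dispatch summary for _matmul above (shapes illustrative):
#   (4,)       @ (4,)        -> Dot                         (1-D x 1-D)
#   (10, 4)    @ (4, 10)     -> extentedMatrixMulOp         (plain 2-D)
#   (3, 10, 4) @ (4, 10)     -> extentedMatrixMulOp         (left folded to (30, 4))
#   (3, 10, 4) @ (3, 4, 10)  -> extentedBatchedMatrixMulOp  (right is batched, dim2 > 2)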
def _transpose(data, axes):

@@ -8,24 +8,18 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
from functools import lru_cache
from typing import Iterable, Optional, Sequence, Tuple, Union

from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze
from .elemwise import clip
from .tensor import expand_dims, squeeze

__all__ = [
    "argmax",

@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
    return result

class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(GetVarShape(), result)
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp

@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
def matmul(
    inp1: Tensor,
    inp2: Tensor,

@@ -1067,50 +838,7 @@ def matmul(
        [[10. 13.]
         [28. 40.]]
    """
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        return dot(inp1, inp2)
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(get_execution_strategy()),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(get_execution_strategy()),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result
    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
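
# After this refactor the public matmul is a thin wrapper: both it and the @
# operator (Tensor.__matmul__, which goes through array_method._matmul) share
# one dispatch path. A typical call, assuming MegEngine's public API:
#
#     import numpy as np
#     import megengine as mge
#     import megengine.functional as F
#
#     a = mge.tensor(np.random.rand(3, 10, 4).astype("float32"))
#     b = mge.tensor(np.random.rand(3, 4, 10).astype("float32"))
#     out = F.matmul(a, b)  # hits the batched subgraph path
#     assert out.numpy().shape == (3, 10, 10)
#     assert np.allclose(out.numpy(), a.numpy() @ b.numpy(), atol=1e-5)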
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:

@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
@pytest.mark.parametrize(
    "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
)
def test_matmul(is_varnode, shape_a, shape_b):
    if is_varnode:
        network = Network()
    else:
        network = None
    A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
    B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
    A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
    B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
    C = A @ B
    if is_varnode:
        np.testing.assert_almost_equal(