GitOrigin-RevId: bf6136cc1a
tags/v1.9.0
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import abc
 import collections
+from functools import lru_cache
 from typing import Union

 import numpy as np
@@ -24,8 +25,8 @@ from .utils import (
     astype,
     cast_tensors,
     convert_inputs,
-    isscalar,
     make_shape_tuple,
+    subgraph,
 )

 _ElwMod = builtin.Elemwise.Mode
@@ -73,23 +74,292 @@ def _elwise(*args, mode):
     return _elwise_apply(args, mode)


-def _matmul(inp1, inp2):
+@lru_cache(maxsize=None)
+def _get_extentedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
+    def extentedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-1):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-1):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(builtin.GetVarShape(), inp1)
+        shape2 = f(builtin.GetVarShape(), inp2)
+        if _dim1 > 2:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
+                    build_shape_tail(shape1),
+                ),
+            )
+        if _dim2 > 2:
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.MatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy.value,
+        )
+        result = f(op, inp1, inp2)
+        result_shape = f(builtin.GetVarShape(), result)
+        if _dim1 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape1),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        if _dim2 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape2),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedMatrixMulOp
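
Everything above builds a cached MatrixMul wrapper: _get_extentedMatrixMulOp is memoized with lru_cache, and the subgraph it returns emulates np.matmul's leading-axis handling by collapsing all but the last axis of a >2-D operand into the row dimension, running a single 2-D MatrixMul, and restoring the leading axes afterwards. A minimal NumPy sketch of that shape logic (illustration only, not the MegEngine subgraph API; it assumes the second operand is at most 2-D, which is what the MatrixMul dispatch branch below guarantees):

    import numpy as np

    def flattened_matmul_sketch(a, b):
        # Fold the leading axes of `a` into its row axis, do one 2-D matmul,
        # then unfold -- the same reshape round-trip the subgraph performs.
        head = a.shape[:-1]                  # shape1[:-1], kept for the final reshape
        a2 = a.reshape(-1, a.shape[-1])      # (prod(head), K)
        out = a2 @ b                         # (prod(head), N)
        return out.reshape(*head, b.shape[-1])

    a = np.random.rand(3, 10, 4).astype("float32")
    b = np.random.rand(4, 10).astype("float32")
    assert flattened_matmul_sketch(a, b).shape == (3, 10, 10)
    assert np.allclose(flattened_matmul_sketch(a, b), a @ b)
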
+
+
+@lru_cache(maxsize=None)
+def _get_extentedBatchedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
+    def extentedBatchedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-2):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-2):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(builtin.GetVarShape(), inp1)
+        shape2 = f(builtin.GetVarShape(), inp2)
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if _dim1 > _dim2:
+            # broadcast
+            shape2 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
+                shape2,
+            )
+            inp2 = f(builtin.Broadcast(), inp2, shape2)
+            batch_shape = build_shape_head(shape1)
+        if _dim2 > _dim1:
+            # broadcast
+            shape1 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
+                shape1,
+            )
+            inp1 = f(builtin.Broadcast(), inp1, shape1)
+            batch_shape = build_shape_head(shape2)
+        if _dim1 == _dim2:
+            batch_shape = build_shape_head(shape1)
+
+        # compress inputs to 3d
+        if maxdim > 3:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape1),
+                ),
+            )
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.BatchedMatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy.value,
+        )
+        result = f(op, inp1, inp2)
+        if maxdim > 3:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    batch_shape,
+                    build_shape_tail(f(builtin.GetVarShape(), result)),
+                ),
+            )
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedBatchedMatrixMulOp
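
The batched variant differs in two ways: the leading (batch) axes are first right-aligned and broadcast between the operands, and when more than one batch axis remains they are compressed into a single one so a rank-3 BatchedMatrixMul can run. A NumPy sketch of the intent (slightly more general than the subgraph, which only prefix-pads the shorter operand; names here are illustrative):

    import numpy as np

    def batched_matmul_sketch(a, b):
        # Broadcast batch axes, fold them into one, multiply, unfold.
        batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
        a3 = np.broadcast_to(a, batch + a.shape[-2:]).reshape(-1, *a.shape[-2:])
        b3 = np.broadcast_to(b, batch + b.shape[-2:]).reshape(-1, *b.shape[-2:])
        out = a3 @ b3                        # (prod(batch), M, N)
        return out.reshape(*batch, *out.shape[-2:])

    a = np.random.rand(2, 3, 5, 4)
    b = np.random.rand(3, 4, 6)              # batch axes broadcast against (2, 3)
    assert batched_matmul_sketch(a, b).shape == (2, 3, 5, 6)
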
+
+
+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
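
_Hashable exists only to make the factory caches work: lru_cache keys on its arguments, so every argument must be hashable, and the execution-strategy value is wrapped so it hashes by its string form and compares by value. A self-contained sketch of the mechanism:

    from functools import lru_cache

    class Hashable:                      # same idea as _Hashable above
        def __init__(self, value):
            self.value = value
        def __hash__(self):
            return hash(str(self.value))
        def __eq__(self, o):
            return isinstance(o, Hashable) and self.value == o.value

    @lru_cache(maxsize=None)
    def build_op(strategy):
        return object()                  # stands in for an expensive subgraph build

    # Distinct wrappers around equal values hit the same cache entry:
    assert build_op(Hashable(3)) is build_op(Hashable(3))
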
+
+
+def _matmul(
+    inp1,
+    inp2,
+    transpose_a=False,
+    transpose_b=False,
+    compute_mode="default",
+    format="default",
+):
     if amp._enabled:
         compute_mode = "float32"
         inp1, inp2 = cast_tensors(inp1, inp2)
     else:
-        compute_mode = "default"
         dtype = dtype_promotion(inp1, inp2)
         if inp1.dtype != dtype:
             inp1 = inp1.astype(dtype)
         if inp2.dtype != dtype:
             inp2 = inp2.astype(dtype)
+    dim1, dim2 = inp1.ndim, inp2.ndim
+    assert dim1 > 0 and dim2 > 0
+    maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
-    )
-    (result,) = apply(op, inp1, inp2)
-    return result
+    Strategy = builtin.ops.MatrixMul.Strategy
+    strategy = Strategy(0)
+    if _config._benchmark_kernel:
+        strategy |= Strategy.PROFILE
+    else:
+        strategy |= Strategy.HEURISTIC
+    if _config._deterministic_kernel:
+        strategy |= Strategy.REPRODUCIBLE
+    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
+        (result,) = apply(builtin.Dot(), inp1, inp2)
+        return result
+    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
+        extentedMatrixMulOp = _get_extentedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=_Hashable(strategy),
+        )
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
+        return result
+    else:  # dispatch to BatchedMatrixMul
+        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=_Hashable(strategy),
+        )
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
+        return result
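
The rewritten _matmul therefore dispatches on operand rank: two 1-D tensors go to Dot, anything whose second operand is at most 2-D goes to the flattening MatrixMul subgraph, and the rest to the broadcasting BatchedMatrixMul subgraph. A hypothetical helper summarizing the resulting shapes under np.matmul semantics (the helper and its name are illustrative, not part of the diff):

    def _matmul_result_shape(shape1, shape2):
        dim1, dim2 = len(shape1), len(shape2)
        if dim1 == 1 and dim2 == 1:          # Dot: scalar result
            return ()
        if dim2 <= 2:                        # MatrixMul: fold leading axes of inp1
            tail = () if dim2 == 1 else (shape2[-1],)
            return shape1[:-1] + tail
        # BatchedMatrixMul: batch axes broadcast, last two axes matrix-multiplied
        row = () if dim1 == 1 else (shape1[-2],)
        batch = shape2[:-2] if dim2 > dim1 else shape1[:-2]
        return batch + row + (shape2[-1],)

    assert _matmul_result_shape((3, 10, 4), (4,)) == (3, 10)
    assert _matmul_result_shape((10, 4), (4, 10)) == (10, 10)
    assert _matmul_result_shape((2, 3, 5, 4), (2, 3, 4, 6)) == (2, 3, 5, 6)
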


 def _transpose(data, axes):
@@ -8,24 +8,18 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import math
-from functools import lru_cache
 from typing import Iterable, Optional, Sequence, Tuple, Union

-from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
-from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
-from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
-from ..core.tensor import amp
-from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
-from ..jit import exclude_from_trace
+from ..core.tensor.array_method import _matmul
+from ..core.tensor.utils import _normalize_axis
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_kwargs_default
-from .debug_param import get_execution_strategy
-from .elemwise import clip, minimum
-from .tensor import broadcast_to, concat, expand_dims, squeeze
+from .elemwise import clip
+from .tensor import expand_dims, squeeze

 __all__ = [
     "argmax",
@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
     return result


-class _Hashable:
-    def __init__(self, value) -> None:
-        self.value = value
-
-    def __hash__(self) -> int:
-        return hash(str(self.value))
-
-    def __eq__(self, o: object) -> bool:
-        if not isinstance(o, _Hashable):
-            return False
-        return self.value == o.value
-
-
-@lru_cache(maxsize=None)
-def _get_extentedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-1):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-1):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(GetVarShape(), inp1)
-        shape2 = f(GetVarShape(), inp2)
-        if _dim1 > 2:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
-                    build_shape_tail(shape1),
-                ),
-            )
-        if _dim2 > 2:
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        result_shape = f(GetVarShape(), result)
-        if _dim1 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape1),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        if _dim2 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape2),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedMatrixMulOp
-
-
-@lru_cache(maxsize=None)
-def _get_extentedBatchedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedBatchedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-2):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-2):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(GetVarShape(), inp1)
-        shape2 = f(GetVarShape(), inp2)
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if _dim1 > _dim2:
-            # broadcast
-            shape2 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
-                shape2,
-            )
-            inp2 = f(builtin.Broadcast(), inp2, shape2)
-            batch_shape = build_shape_head(shape1)
-        if _dim2 > _dim1:
-            # broadcast
-            shape1 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
-                shape1,
-            )
-            inp1 = f(builtin.Broadcast(), inp1, shape1)
-            batch_shape = build_shape_head(shape2)
-        if _dim1 == _dim2:
-            batch_shape = build_shape_head(shape1)
-
-        # compress inputs to 3d
-        if maxdim > 3:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape1),
-                ),
-            )
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        if maxdim > 3:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    batch_shape,
-                    build_shape_tail(f(GetVarShape(), result)),
-                ),
-            )
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedBatchedMatrixMulOp
-
-
 def matmul(
     inp1: Tensor,
     inp2: Tensor,
@@ -1067,50 +838,7 @@ def matmul(
     [[10. 13.]
      [28. 40.]]
     """
-    if amp._enabled:
-        compute_mode = "float32"
-        inp1, inp2 = cast_tensors(inp1, inp2)
-    else:
-        dtype = dtype_promotion(inp1, inp2)
-        if inp1.dtype != dtype:
-            inp1 = inp1.astype(dtype)
-        if inp2.dtype != dtype:
-            inp2 = inp2.astype(dtype)
-    dim1, dim2 = inp1.ndim, inp2.ndim
-    assert dim1 > 0 and dim2 > 0
-    maxdim = dim1 if dim1 > dim2 else dim2
-    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
-        return dot(inp1, inp2)
-    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
-        extentedMatrixMulOp = _get_extentedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
-            dim1,
-            dim2,
-            transpose_a,
-            transpose_b,
-            compute_mode,
-            format,
-            strategy=_Hashable(get_execution_strategy()),
-        )
-        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
-        return result
-    else:  # dispatch to BatchedMatrixMul
-        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
-            dim1,
-            dim2,
-            transpose_a,
-            transpose_b,
-            compute_mode,
-            format,
-            strategy=_Hashable(get_execution_strategy()),
-        )
-        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
-        return result
+    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
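
With the heavy lifting moved into core.tensor.array_method, functional.matmul is now a thin wrapper, and the `@` operator on Tensor takes the same path. A quick usage sketch, assuming a working MegEngine install, with shapes chosen to hit each dispatch branch:

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    v = Tensor(np.random.rand(4).astype("float32"))
    F.matmul(v, v)                            # 1-D x 1-D: Dot

    a = Tensor(np.random.rand(3, 10, 4).astype("float32"))
    b = Tensor(np.random.rand(4, 10).astype("float32"))
    F.matmul(a, b)                            # dim2 <= 2: MatrixMul subgraph

    c = Tensor(np.random.rand(3, 4, 10).astype("float32"))
    a @ c                                     # both > 2-D: BatchedMatrixMul subgraph
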


 def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):

 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_matmul(is_varnode):
+@pytest.mark.parametrize(
+    "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
+)
+def test_matmul(is_varnode, shape_a, shape_b):
     if is_varnode:
         network = Network()
     else:
         network = None

-    A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
-    B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
+    A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
+    B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
     C = A @ B

     if is_varnode:
         np.testing.assert_almost_equal(