Browse Source

fix(imperative/tensor): make @ operator has the same functionality as matmul functional

GitOrigin-RevId: bf6136cc1a
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
ca2deebc0f
3 changed files with 289 additions and 288 deletions
  1. +278
    -8
      imperative/python/megengine/core/tensor/array_method.py
  2. +5
    -277
      imperative/python/megengine/functional/math.py
  3. +6
    -3
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 278
- 8
imperative/python/megengine/core/tensor/array_method.py View File

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc import abc
import collections import collections
from functools import lru_cache
from typing import Union from typing import Union


import numpy as np import numpy as np
@@ -24,8 +25,8 @@ from .utils import (
astype, astype,
cast_tensors, cast_tensors,
convert_inputs, convert_inputs,
isscalar,
make_shape_tuple, make_shape_tuple,
subgraph,
) )


_ElwMod = builtin.Elemwise.Mode _ElwMod = builtin.Elemwise.Mode
@@ -73,23 +74,292 @@ def _elwise(*args, mode):
return _elwise_apply(args, mode) return _elwise_apply(args, mode)




def _matmul(inp1, inp2):
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
def extentedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2

def build_shape_head(shape, idx=-1):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)

def build_shape_tail(shape, idx=-1):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)

remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True

if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

shape1 = f(builtin.GetVarShape(), inp1)
shape2 = f(builtin.GetVarShape(), inp2)
if _dim1 > 2:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
build_shape_tail(shape1),
),
)
if _dim2 > 2:
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
build_shape_tail(shape2),
),
)
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy.value,
)
result = f(op, inp1, inp2)
result_shape = f(builtin.GetVarShape(), result)
if _dim1 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1),
build_shape_tail(result_shape),
),
)
if _dim2 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2),
build_shape_tail(result_shape),
),
)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)

return extentedMatrixMulOp


@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
def extentedBatchedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2

def build_shape_head(shape, idx=-2):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)

def build_shape_tail(shape, idx=-2):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)

remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True

if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
shape1 = f(builtin.GetVarShape(), inp1)
shape2 = f(builtin.GetVarShape(), inp2)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if _dim1 > _dim2:
# broadcast
shape2 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2]
shape2,
)
inp2 = f(builtin.Broadcast(), inp2, shape2)
batch_shape = build_shape_head(shape1)
if _dim2 > _dim1:
# broadcast
shape1 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1]
shape1,
)
inp1 = f(builtin.Broadcast(), inp1, shape1)
batch_shape = build_shape_head(shape2)
if _dim1 == _dim2:
batch_shape = build_shape_head(shape1)

# compress inputs to 3d
if maxdim > 3:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape1),
),
)
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape2),
),
)
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy.value,
)
result = f(op, inp1, inp2)

if maxdim > 3:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
batch_shape,
build_shape_tail(f(builtin.GetVarShape(), result)),
),
)
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)

return extentedBatchedMatrixMulOp


class _Hashable:
def __init__(self, value) -> None:
self.value = value

def __hash__(self) -> int:
return hash(str(self.value))

def __eq__(self, o: object) -> bool:
if not isinstance(o, _Hashable):
return False
return self.value == o.value


def _matmul(
inp1,
inp2,
transpose_a=False,
transpose_b=False,
compute_mode="default",
format="default",
):
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2) inp1, inp2 = cast_tensors(inp1, inp2)
else: else:
compute_mode = "default"
dtype = dtype_promotion(inp1, inp2) dtype = dtype_promotion(inp1, inp2)
if inp1.dtype != dtype: if inp1.dtype != dtype:
inp1 = inp1.astype(dtype) inp1 = inp1.astype(dtype)
if inp2.dtype != dtype: if inp2.dtype != dtype:
inp2 = inp2.astype(dtype) inp2 = inp2.astype(dtype)

dim1, dim2 = inp1.ndim, inp2.ndim
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
)
(result,) = apply(op, inp1, inp2)
return result

Strategy = builtin.ops.MatrixMul.Strategy
strategy = Strategy(0)
if _config._benchmark_kernel:
strategy |= Strategy.PROFILE
else:
strategy |= Strategy.HEURISTIC
if _config._deterministic_kernel:
strategy |= Strategy.REPRODUCIBLE

if dim1 == 1 and dim2 == 1: # dispatch to Dot
(result,) = apply(builtin.Dot(), inp1, inp2)
return result
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
extentedMatrixMulOp = _get_extentedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
)
(result,) = apply(extentedMatrixMulOp(), inp1, inp2)
return result
else: # dispath to BatchedMatrixMul
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
)
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
return result




def _transpose(data, axes): def _transpose(data, axes):


+ 5
- 277
imperative/python/megengine/functional/math.py View File

@@ -8,24 +8,18 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import math import math
from functools import lru_cache
from typing import Iterable, Optional, Sequence, Tuple, Union from typing import Iterable, Optional, Sequence, Tuple, Union


from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze
from .elemwise import clip
from .tensor import expand_dims, squeeze


__all__ = [ __all__ = [
"argmax", "argmax",
@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
return result return result




class _Hashable:
def __init__(self, value) -> None:
self.value = value

def __hash__(self) -> int:
return hash(str(self.value))

def __eq__(self, o: object) -> bool:
if not isinstance(o, _Hashable):
return False
return self.value == o.value


@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
def extentedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2

def build_shape_head(shape, idx=-1):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)

def build_shape_tail(shape, idx=-1):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)

remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True

if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

shape1 = f(GetVarShape(), inp1)
shape2 = f(GetVarShape(), inp2)
if _dim1 > 2:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
build_shape_tail(shape1),
),
)
if _dim2 > 2:
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
build_shape_tail(shape2),
),
)
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy.value,
)
result = f(op, inp1, inp2)
result_shape = f(GetVarShape(), result)
if _dim1 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1),
build_shape_tail(result_shape),
),
)
if _dim2 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2),
build_shape_tail(result_shape),
),
)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)

return extentedMatrixMulOp


@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
def extentedBatchedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2

def build_shape_head(shape, idx=-2):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)

def build_shape_tail(shape, idx=-2):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)

remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True

if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
shape1 = f(GetVarShape(), inp1)
shape2 = f(GetVarShape(), inp2)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if _dim1 > _dim2:
# broadcast
shape2 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2]
shape2,
)
inp2 = f(builtin.Broadcast(), inp2, shape2)
batch_shape = build_shape_head(shape1)
if _dim2 > _dim1:
# broadcast
shape1 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1]
shape1,
)
inp1 = f(builtin.Broadcast(), inp1, shape1)
batch_shape = build_shape_head(shape2)
if _dim1 == _dim2:
batch_shape = build_shape_head(shape1)

# compress inputs to 3d
if maxdim > 3:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape1),
),
)
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape2),
),
)
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy.value,
)
result = f(op, inp1, inp2)

if maxdim > 3:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
batch_shape,
build_shape_tail(f(GetVarShape(), result)),
),
)
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)

return extentedBatchedMatrixMulOp


def matmul( def matmul(
inp1: Tensor, inp1: Tensor,
inp2: Tensor, inp2: Tensor,
@@ -1067,50 +838,7 @@ def matmul(
[[10. 13.] [[10. 13.]
[28. 40.]] [28. 40.]]
""" """
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
dtype = dtype_promotion(inp1, inp2)
if inp1.dtype != dtype:
inp1 = inp1.astype(dtype)
if inp2.dtype != dtype:
inp2 = inp2.astype(dtype)

dim1, dim2 = inp1.ndim, inp2.ndim
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
if dim1 == 1 and dim2 == 1: # dispatch to Dot
return dot(inp1, inp2)
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
extentedMatrixMulOp = _get_extentedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(get_execution_strategy()),
)
(result,) = apply(extentedMatrixMulOp(), inp1, inp2)
return result
else: # dispath to BatchedMatrixMul
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(get_execution_strategy()),
)
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
return result
return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)




def dot(inp1: Tensor, inp2: Tensor) -> Tensor: def dot(inp1: Tensor, inp2: Tensor) -> Tensor:


+ 6
- 3
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):




@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
@pytest.mark.parametrize(
"shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
)
def test_matmul(is_varnode, shape_a, shape_b):
if is_varnode: if is_varnode:
network = Network() network = Network()
else: else:
network = None network = None


A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
C = A @ B C = A @ B
if is_varnode: if is_varnode:
np.testing.assert_almost_equal( np.testing.assert_almost_equal(


Loading…
Cancel
Save