Browse Source

perf(functional/matmul): reimplement matmul with subgraph

GitOrigin-RevId: 456b2a51d3
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
3206af9db2
1 changed files with 243 additions and 75 deletions
  1. +243
    -75
      imperative/python/megengine/functional/math.py

+ 243
- 75
imperative/python/megengine/functional/math.py View File

@@ -8,17 +8,21 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import math import math
from functools import lru_cache
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union


from ..core._imperative_rt.core2 import apply, dtype_promotion from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._trace_option import use_symbolic_shape from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import amp from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..jit import exclude_from_trace
from ..tensor import Tensor from ..tensor import Tensor
from .debug_param import get_execution_strategy from .debug_param import get_execution_strategy
from .elemwise import clip
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze from .tensor import broadcast_to, concat, expand_dims, squeeze


__all__ = [ __all__ = [
@@ -763,6 +767,216 @@ def matinv(inp: Tensor) -> Tensor:
return result return result




@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
def extentedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2

def build_shape_head(shape, idx=-1):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)

def build_shape_tail(shape, idx=-1):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)

remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True

if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

shape1 = f(GetVarShape(), inp1)
shape2 = f(GetVarShape(), inp2)
if _dim1 > 2:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
build_shape_tail(shape1),
),
)
if _dim2 > 2:
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
build_shape_tail(shape2),
),
)
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy,
)
result = f(op, inp1, inp2)
result_shape = f(GetVarShape(), result)
if _dim1 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1),
build_shape_tail(result_shape),
),
)
if _dim2 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2),
build_shape_tail(result_shape),
),
)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)

return extentedMatrixMulOp


@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
def extentedBatchedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2

def build_shape_head(shape, idx=-2):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)

def build_shape_tail(shape, idx=-2):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)

remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True

if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
shape1 = f(GetVarShape(), inp1)
shape2 = f(GetVarShape(), inp2)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if _dim1 > _dim2:
# broadcast
shape2 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2]
shape2,
)
inp2 = f(builtin.Broadcast(), inp2, shape2)
batch_shape = build_shape_head(shape1)
if _dim2 > _dim1:
# broadcast
shape1 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1]
shape1,
)
inp1 = f(builtin.Broadcast(), inp1, shape1)
batch_shape = build_shape_head(shape2)
if _dim1 == _dim2:
batch_shape = build_shape_head(shape1)

# compress inputs to 3d
if maxdim > 3:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape1),
),
)
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape2),
),
)
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy,
)
result = f(op, inp1, inp2)

if maxdim > 3:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
batch_shape,
build_shape_tail(f(GetVarShape(), result)),
),
)
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)

return extentedBatchedMatrixMulOp


def matmul( def matmul(
inp1: Tensor, inp1: Tensor,
inp2: Tensor, inp2: Tensor,
@@ -822,85 +1036,39 @@ def matmul(
if inp2.dtype != dtype: if inp2.dtype != dtype:
inp2 = inp2.astype(dtype) inp2 = inp2.astype(dtype)


remove_row, remove_col = False, False
dim1, dim2 = inp1.ndim, inp2.ndim dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
if dim1 == 1 and dim2 == 1:
return dot(inp1, inp2)
# the underlying matmul op requires input dims to be at least 2
if dim1 == 1:
inp1 = expand_dims(inp1, 0)
dim1 = 2
remove_row = True
if dim2 == 1:
inp2 = expand_dims(inp2, 1)
dim2 = 2
remove_col = True

batch_shape = None
shape1 = inp1.shape
shape2 = inp2.shape

assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2 maxdim = dim1 if dim1 > dim2 else dim2
if dim1 >= 3 or dim2 >= 3:
if use_symbolic_shape():
if dim1 > dim2:
shape2 = concat([shape1[:-2], shape2[-2:]])
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = concat([shape2[:-2], shape1[-2:]])
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
(inp1,) = apply(
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
)
(inp2,) = apply(
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
)
else:
if dim1 > dim2:
shape2 = shape1[:-2] + shape2[-2:]
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = shape2[:-2] + shape1[-2:]
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))

op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
if dim1 == 1 and dim2 == 1: # dispatch to Dot
return dot(inp1, inp2)
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
extentedMatrixMulOp = _get_extentedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
) )
else:
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
(result,) = apply(extentedMatrixMulOp, inp1, inp2)
return result
else: # dispath to BatchedMatrixMul
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
) )

(result,) = apply(op, inp1, inp2)
if maxdim > 3:
if use_symbolic_shape():
(result,) = apply(
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
)
else:
result = result.reshape(batch_shape + result.shape[-2:])
if remove_row:
result = squeeze(result, axis=-2)
if remove_col:
result = squeeze(result, axis=-1)
return result
(result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
return result




def dot(inp1: Tensor, inp2: Tensor) -> Tensor: def dot(inp1: Tensor, inp2: Tensor) -> Tensor:


Loading…
Cancel
Save