@@ -12,6 +12,7 @@ from typing import Union
 import numpy as np
 
+from .. import _config
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
 from ..ops import builtin
@@ -87,6 +88,7 @@ def _matmul(inp1, inp2):
         inp1 = inp1.astype(dtype)
     if inp2.dtype != dtype:
         inp2 = inp2.astype(dtype)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     op = builtin.MatrixMul(
         transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
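The `_get_actual_op_param` helper is the core of this change: it decides whether the per-op argument or the global `_config` value takes effect. A minimal sketch of the fallback semantics implied by the call sites (a hypothetical reimplementation for illustration, not the actual `_config` source):

```python
# Hypothetical sketch: the global config value wins unless it is still the
# sentinel "default", in which case the per-op argument passes through.
def _get_actual_op_param(op_param, config_param):
    return op_param if config_param == "default" else config_param

# No global override set -> the op keeps its own compute_mode:
assert _get_actual_op_param("float32", "default") == "float32"
# Global override set -> it replaces the op's hardcoded default:
assert _get_actual_op_param("NCHW", "NHWC") == "NHWC"
```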
@@ -11,6 +11,7 @@ import math
 from functools import lru_cache
 from typing import Optional, Sequence, Tuple, Union
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._trace_option import use_symbolic_shape
@@ -1077,6 +1078,7 @@ def matmul(
     dim1, dim2 = inp1.ndim, inp2.ndim
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         return dot(inp1, inp2)
     elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
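For context, `matmul` picks its kernel by rank: 1-D × 1-D dispatches to `Dot`, rank ≤ 2 pairs go to `MatrixMul`, and higher ranks take the batched path. A usage sketch (assumes a working MegEngine install):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

a = mge.tensor(np.ones(3, dtype="float32"))
b = mge.tensor(np.ones(3, dtype="float32"))
print(F.matmul(a, b))        # 1-D x 1-D -> Dot, a scalar (3.0)

A = mge.tensor(np.ones((2, 3), dtype="float32"))
B = mge.tensor(np.ones((3, 4), dtype="float32"))
print(F.matmul(A, B).shape)  # 2-D x 2-D -> MatrixMul, shape (2, 4)
```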
@@ -10,6 +10,7 @@
 from functools import lru_cache
 from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
@@ -115,6 +116,7 @@ def linear(
         weight: weight with shape `(out_features, in_features)`.
         bias: bias with shape `(out_features,)`. Default: None
     """
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
         if amp._enabled:
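The effect in `linear` (and in the conv ops below) is that a globally configured compute mode replaces the per-call default before the builtin op is built. How the override gets set is internal; assuming `__compute_mode` is a plain module attribute whose unset value is `"default"` (names taken from the diff, not a public API):

```python
from megengine.core import _config

# Assumption: setting the internal attribute makes every op that calls
# _get_actual_op_param(compute_mode, _config.__compute_mode) accumulate
# fp16 matmuls/convs in fp32.
_config.__compute_mode = "float32"
```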
@@ -185,6 +187,8 @@ def conv1d(
     pad_h = padding
     dilate_h = dilation
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution(
         stride_h=stride_h,
@@ -197,6 +201,7 @@ def conv1d(
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
+        format=conv_format,
     )
     (output,) = apply(op, inp, weight)
     if bias is not None:
@@ -261,6 +266,8 @@ def conv2d(
     dilate_h, dilate_w = expand_hw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Convolution(
         stride_h=stride_h,
         stride_w=stride_w,
@@ -272,6 +279,7 @@ def conv2d(
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
+        format=conv_format,
     )
     (output,) = apply(op, inp, weight)
     if bias is not None:
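To make the resolved parameters concrete, here is the conv2d builtin call with every argument resolved to its default (a sketch mirroring the kwargs above for a dense, stride-1 convolution; `inp` and `weight` are assumed tensors, and `F.conv2d` remains the supported entry point):

```python
op = builtin.Convolution(
    stride_h=1, stride_w=1,
    pad_h=0, pad_w=0,
    dilate_h=1, dilate_w=1,
    mode="cross_correlation",  # the usual conv_mode default
    compute_mode="default",    # no global override in effect
    sparse="dense",            # groups == 1
    format="NCHW",             # conv_format with no global override
)
(output,) = apply(op, inp, weight)
```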
@@ -403,6 +411,7 @@ def conv_transpose2d(
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     op = builtin.ConvolutionBackwardData(
         stride_h=stride_h,
@@ -474,6 +483,7 @@ def deformable_conv2d(
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.DeformableConv(
         stride_h=stride_h,
@@ -614,6 +624,7 @@ def max_pool2d(
     window_h, window_w = _pair_nonzero(kernel_size)
     stride_h, stride_w = _pair_nonzero(stride)
     padding_h, padding_w = _pair(padding)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
         window_h=window_h,
@@ -623,6 +634,7 @@ def max_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode="max",
+        format=conv_format,
     )
     (output,) = apply(op, inp)
     return output
@@ -656,6 +668,7 @@ def avg_pool2d(
     window_h, window_w = _pair_nonzero(kernel_size)
     stride_h, stride_w = _pair_nonzero(stride)
     padding_h, padding_w = _pair(padding)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
         window_h=window_h,
@@ -665,6 +678,7 @@ def avg_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode=mode,
+        format=conv_format,
     )
     (output,) = apply(op, inp)
     return output
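The `format` now threaded through the pooling ops matters because the same logical element sits at a different index in each layout. A quick NumPy illustration:

```python
import numpy as np

x_nchw = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # (N, C, H, W)
x_nhwc = x_nchw.transpose(0, 2, 3, 1)                  # (N, H, W, C)

# Same logical pixel (n=0, c=1, h=2, w=3), different index order:
assert x_nchw[0, 1, 2, 3] == x_nhwc[0, 2, 3, 1]
```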
@@ -686,8 +700,9 @@ def adaptive_max_pool2d(
     """
     if isinstance(oshp, int):
         oshp = (oshp, oshp)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
-    op = builtin.AdaptivePooling(mode="max", format="NCHW",)
+    op = builtin.AdaptivePooling(mode="max", format=conv_format,)
     oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
     (output,) = apply(op, inp, oshp)
     return output
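Unlike `max_pool2d`, the adaptive variant fixes the output spatial shape and derives the window from it, which is why `oshp` is passed to the op as a tensor. Usage sketch (assumes a working MegEngine install):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.rand(1, 3, 32, 32).astype("float32"))
y = F.adaptive_max_pool2d(x, 7)  # an int oshp expands to (7, 7)
print(y.shape)                   # (1, 3, 7, 7) regardless of input H/W
```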
@@ -8,6 +8,7 @@
 # pylint: disable=too-many-lines
 from typing import Tuple, Union
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..tensor import Tensor
@@ -55,6 +56,8 @@ def conv_bias_activation(
     sh, sw = _pair_nonzero(stride)
     dh, dw = _pair_nonzero(dilation)
     sparse_type = "dense" if groups == 1 else "group"
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.ConvBias(
         stride_h=sh,
         stride_w=sw,
@@ -63,7 +66,7 @@ def conv_bias_activation(
         dilate_h=dh,
         dilate_w=dw,
         dtype=dtype,
-        format="NCHW",
+        format=conv_format,
         strategy=get_execution_strategy(),
         nonlineMode=nonlinear_mode,
         mode=conv_mode,
@@ -114,6 +117,8 @@ def batch_conv_bias_activation(
     sh, sw = _pair_nonzero(stride)
     dh, dw = _pair_nonzero(dilation)
     sparse_type = "dense" if groups == 1 else "group"
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.BatchConvBias(
         stride_h=sh,
         stride_w=sw,
@@ -122,7 +127,7 @@ def batch_conv_bias_activation(
         dilate_h=dh,
         dilate_w=dw,
         dtype=dtype,
-        format="NCHW",
+        format=conv_format,
         strategy=get_execution_strategy(),
         nonlineMode=nonlinear_mode,
         mode=conv_mode,
@@ -164,6 +169,7 @@ def conv_transpose2d(
     pad_h, pad_w = _pair(padding)
     stride_h, stride_w = _pair_nonzero(stride)
     dilate_h, dilate_w = _pair_nonzero(dilation)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     # should be replaced by Op with bias such as ConvolutionBackwardDataBias
     op = builtin.ConvolutionBackwardData(
@@ -10,6 +10,7 @@ from typing import Iterable, Optional, Tuple, Union
 import numpy as np
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..core.tensor import megbrain_graph, utils
@@ -143,9 +144,11 @@ def correlation(
         pad_size: int (non-negative), optional, default=0 – pad for Correlation
         is_multiply: boolean, optional, default=True – operation type is either multiplication or absolute difference
     """
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
+    assert conv_format == "NCHW", "Currently correlation only supports NCHW mode"
     op = builtin.Correlation(
-        format="NCHW",
+        format=conv_format,
         kernel_size=kernel_size,
         max_displacement=max_displacement,
         stride1=stride1,
@@ -215,10 +218,12 @@ def roi_align(
         sample_points = (sample_points, sample_points)
     sample_height, sample_width = sample_points
     offset = 0.5 if aligned else 0.0
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
+    assert conv_format == "NCHW", "Currently roi_align only supports NCHW mode"
     op = builtin.ROIAlign(
         mode=mode,
-        format="NCHW",
+        format=conv_format,
         spatial_scale=spatial_scale,
         offset=offset,
         pooled_height=pooled_height,
@@ -343,9 +348,10 @@ def remap(
        [[[[1. 4.]
           [4. 4.]]]]
     """
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Remap(
-        imode=interp_mode, border_type=border_mode, format="NCHW", scalar=scalar
+        imode=interp_mode, border_type=border_mode, format=conv_format, scalar=scalar
     )
     assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
     (result,) = apply(op, inp, map_xy)
@@ -384,10 +390,12 @@ def warp_affine(
     however it does not mean that you can use all the combinations.
     On different platforms, different combinations are supported.
     """
+    conv_format = _config._get_actual_op_param(format, _config.__conv_format)
     op = builtin.WarpAffine(
         border_mode=border_mode,
         border_val=border_val,
-        format=format,
+        format=conv_format,
         imode=interp_mode,
     )
     out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
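Note the asymmetry: the conv and pooling ops resolve against a hardcoded `"NCHW"`, while `warp_affine` (and `warp_perspective` below) feed the caller's `format` argument into the resolver, so a global override would shadow even an explicitly passed format. Under the fallback semantics sketched earlier for `_get_actual_op_param`:

```python
# No global override -> the caller's explicit format is respected:
assert _get_actual_op_param("NHWC", "default") == "NHWC"
# Global override set -> it shadows the explicit argument too:
assert _get_actual_op_param("NHWC", "NCHW") == "NCHW"
```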
@@ -466,8 +474,9 @@ def warp_perspective(
         mat = mat.astype("float32")
     if inp.dtype == np.float16:
         inp = inp.astype("float32")
+    conv_format = _config._get_actual_op_param(format, _config.__conv_format)
     op = builtin.WarpPerspective(
-        imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
+        imode=interp_mode, bmode=border_mode, format=conv_format, border_val=border_val
     )
     out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
     if mat_idx is not None:
@@ -602,7 +611,9 @@ def interpolate(
        }
        if inp.dtype == np.float16:
            inp = inp.astype("float32")
-        op = builtin.Resize(imode=mode_map[mode], format="NCHW")
+        conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
+        assert conv_format == "NCHW", "Currently resize only supports NCHW mode"
+        op = builtin.Resize(imode=mode_map[mode], format=conv_format)
         shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
         (ret,) = apply(op, inp, shape)
     else: