diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 2db8efb3..92a080c8 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -12,6 +12,7 @@ from typing import Union
 
 import numpy as np
 
+from .. import _config
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
 from ..ops import builtin
@@ -87,6 +88,7 @@ def _matmul(inp1, inp2):
         inp1 = inp1.astype(dtype)
     if inp2.dtype != dtype:
         inp2 = inp2.astype(dtype)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     op = builtin.MatrixMul(
         transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 16adc50a..440caafd 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -11,6 +11,7 @@ import math
 from functools import lru_cache
 from typing import Optional, Sequence, Tuple, Union
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._trace_option import use_symbolic_shape
@@ -1077,6 +1078,7 @@ def matmul(
     dim1, dim2 = inp1.ndim, inp2.ndim
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         return dot(inp1, inp2)
     elif maxdim <= 2 or dim2 <= 2:  # dispath to MatrixMul
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index e741bdf8..5400f8a4 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -10,6 +10,7 @@ from functools import lru_cache
 from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
@@ -115,6 +116,7 @@ def linear(
         weight: weight with shape `(out_features, in_features)`.
         bias: bias with shape `(out_features,)`. Default: None
     """
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
         if amp._enabled:
@@ -185,6 +187,8 @@ def conv1d(
     pad_h = padding
     dilate_h = dilation
 
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution(
         stride_h=stride_h,
@@ -197,6 +201,7 @@ def conv1d(
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
+        format=conv_format,
     )
     (output,) = apply(op, inp, weight)
     if bias is not None:
@@ -261,6 +266,8 @@ def conv2d(
     dilate_h, dilate_w = expand_hw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
 
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Convolution(
         stride_h=stride_h,
         stride_w=stride_w,
@@ -272,6 +279,7 @@ def conv2d(
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
+        format=conv_format,
     )
     (output,) = apply(op, inp, weight)
     if bias is not None:
@@ -403,6 +411,7 @@ def conv_transpose2d(
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
 
     op = builtin.ConvolutionBackwardData(
         stride_h=stride_h,
@@ -474,6 +483,7 @@ def deformable_conv2d(
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
 
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.DeformableConv(
         stride_h=stride_h,
@@ -614,6 +624,7 @@ def max_pool2d(
     window_h, window_w = _pair_nonzero(kernel_size)
     stride_h, stride_w = _pair_nonzero(stride)
     padding_h, padding_w = _pair(padding)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
 
     op = builtin.Pooling(
         window_h=window_h,
@@ -623,6 +634,7 @@ def max_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode="max",
+        format=conv_format,
     )
     (output,) = apply(op, inp)
     return output
@@ -656,6 +668,7 @@ def avg_pool2d(
     window_h, window_w = _pair_nonzero(kernel_size)
     stride_h, stride_w = _pair_nonzero(stride)
     padding_h, padding_w = _pair(padding)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
 
     op = builtin.Pooling(
         window_h=window_h,
@@ -665,6 +678,7 @@ def avg_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode=mode,
+        format=conv_format,
    )
     (output,) = apply(op, inp)
     return output
@@ -686,8 +700,9 @@ def adaptive_max_pool2d(
     """
     if isinstance(oshp, int):
         oshp = (oshp, oshp)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
 
-    op = builtin.AdaptivePooling(mode="max", format="NCHW",)
+    op = builtin.AdaptivePooling(mode="max", format=conv_format,)
     oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
     (output,) = apply(op, inp, oshp)
     return output
diff --git a/imperative/python/megengine/functional/quantized.py b/imperative/python/megengine/functional/quantized.py
index d004b762..dac4ef61 100644
--- a/imperative/python/megengine/functional/quantized.py
+++ b/imperative/python/megengine/functional/quantized.py
@@ -8,6 +8,7 @@
 # pylint: disable=too-many-lines
 from typing import Tuple, Union
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..tensor import Tensor
@@ -55,6 +56,8 @@ def conv_bias_activation(
     sh, sw = _pair_nonzero(stride)
     dh, dw = _pair_nonzero(dilation)
     sparse_type = "dense" if groups == 1 else "group"
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.ConvBias(
         stride_h=sh,
         stride_w=sw,
@@ -63,7 +66,7 @@ def conv_bias_activation(
         dilate_h=dh,
         dilate_w=dw,
         dtype=dtype,
-        format="NCHW",
+        format=conv_format,
         strategy=get_execution_strategy(),
         nonlineMode=nonlinear_mode,
         mode=conv_mode,
@@ -114,6 +117,8 @@ def batch_conv_bias_activation(
     sh, sw = _pair_nonzero(stride)
     dh, dw = _pair_nonzero(dilation)
     sparse_type = "dense" if groups == 1 else "group"
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.BatchConvBias(
         stride_h=sh,
         stride_w=sw,
@@ -122,7 +127,7 @@ def batch_conv_bias_activation(
         dilate_h=dh,
         dilate_w=dw,
         dtype=dtype,
-        format="NCHW",
+        format=conv_format,
         strategy=get_execution_strategy(),
         nonlineMode=nonlinear_mode,
         mode=conv_mode,
@@ -164,6 +169,7 @@ def conv_transpose2d(
     pad_h, pad_w = _pair(padding)
     stride_h, stride_w = _pair_nonzero(stride)
     dilate_h, dilate_w = _pair_nonzero(dilation)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
 
     # should be replaced by Op with bias such as ConvolutionBackwardDataBias
     op = builtin.ConvolutionBackwardData(
diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py
index a0a2e197..03169950 100644
--- a/imperative/python/megengine/functional/vision.py
+++ b/imperative/python/megengine/functional/vision.py
@@ -10,6 +10,7 @@ from typing import Iterable, Optional, Tuple, Union
 
 import numpy as np
 
+from ..core import _config
 from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..core.tensor import megbrain_graph, utils
@@ -143,9 +144,11 @@ def correlation(
         pad_size: int (non-negative), optional, default=0) – pad for Correlation
         is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference
     """
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
+    assert conv_format == "NCHW", "Currently correlation only support NCHW mode"
 
     op = builtin.Correlation(
-        format="NCHW",
+        format=conv_format,
         kernel_size=kernel_size,
         max_displacement=max_displacement,
         stride1=stride1,
@@ -215,10 +218,12 @@ def roi_align(
         sample_points = (sample_points, sample_points)
     sample_height, sample_width = sample_points
     offset = 0.5 if aligned else 0.0
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
+    assert conv_format == "NCHW", "Currently roi_align only support NCHW mode"
 
     op = builtin.ROIAlign(
         mode=mode,
-        format="NCHW",
+        format=conv_format,
         spatial_scale=spatial_scale,
         offset=offset,
         pooled_height=pooled_height,
@@ -343,9 +348,10 @@ def remap(
            [[[[1. 4.]
              [4. 4.]]]]
     """
+    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
 
     op = builtin.Remap(
-        imode=interp_mode, border_type=border_mode, format="NCHW", scalar=scalar
+        imode=interp_mode, border_type=border_mode, format=conv_format, scalar=scalar
     )
     assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
     (result,) = apply(op, inp, map_xy)
@@ -384,10 +390,12 @@ def warp_affine(
         however it does not mean that you can use all the combinations.
         On different platforms, different combinations are supported.
     """
+    conv_format = _config._get_actual_op_param(format, _config.__conv_format)
+
     op = builtin.WarpAffine(
         border_mode=border_mode,
         border_val=border_val,
-        format=format,
+        format=conv_format,
         imode=interp_mode,
     )
     out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
@@ -466,8 +474,9 @@ def warp_perspective(
         mat = mat.astype("float32")
     if inp.dtype == np.float16:
         inp = inp.astype("float32")
+    conv_format = _config._get_actual_op_param(format, _config.__conv_format)
     op = builtin.WarpPerspective(
-        imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
+        imode=interp_mode, bmode=border_mode, format=conv_format, border_val=border_val
     )
     out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
     if mat_idx is not None:
@@ -602,7 +611,9 @@ def interpolate(
         }
         if inp.dtype == np.float16:
             inp = inp.astype("float32")
-        op = builtin.Resize(imode=mode_map[mode], format="NCHW")
+        conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
+        assert conv_format == "NCHW", "Currently resize only support NCHW mode"
+        op = builtin.Resize(imode=mode_map[mode], format=conv_format)
         shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
         (ret,) = apply(op, inp, shape)
     else:
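
For reviewers unfamiliar with the `_config` override mechanism this patch threads through the functional ops, the sketch below illustrates the resolution pattern every changed call site now follows. It is a standalone illustration, not MegEngine source: the helper body and the module-level globals are assumptions standing in for `_config._get_actual_op_param`, `_config.__compute_mode`, and `_config.__conv_format`, under the assumption that the per-call argument is kept whenever the global config is still at its `"default"` sentinel and is otherwise overridden by the global.

```python
# A minimal, self-contained sketch (not MegEngine source) of the override pattern
# the patch applies at every call site: keep the per-call argument unless a
# process-wide config value has been moved away from its "default" sentinel.
# The globals and the helper body below are illustrative assumptions standing in
# for _config.__compute_mode, _config.__conv_format and _config._get_actual_op_param.

_compute_mode = "default"  # hypothetical stand-in for _config.__compute_mode
_conv_format = "default"   # hypothetical stand-in for _config.__conv_format


def _get_actual_op_param(function_param: str, config_param: str) -> str:
    """Return the per-call value unless the global config overrides it."""
    return function_param if config_param == "default" else config_param


def resolve_conv_params(compute_mode: str = "default", format: str = "NCHW"):
    # Mirrors what the patched functional ops now do right before building the op.
    actual_compute_mode = _get_actual_op_param(compute_mode, _compute_mode)
    actual_format = _get_actual_op_param(format, _conv_format)
    return actual_compute_mode, actual_format


print(resolve_conv_params())   # ('default', 'NCHW')  -- nothing overridden
_conv_format = "NHWC"          # simulate flipping the global conv format
print(resolve_conv_params())   # ('default', 'NHWC')  -- global wins over the per-call default
```

Under this assumption, a single global setting overrides the per-call defaults across every patched op, which is why the NCHW-only ops (correlation, roi_align, resize) also add asserts guarding against an unsupported global format.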