Browse Source

fix(imperative): remove convert_inputs

GitOrigin-RevId: a3c43db746
release-1.5
Megvii Engine Team huangxinda 3 years ago
parent
commit
070c811732
14 changed files with 81 additions and 95 deletions
  1. +6
    -4
      imperative/python/megengine/core/tensor/array_method.py
  2. +1
    -1
      imperative/python/megengine/functional/elemwise.py
  3. +9
    -7
      imperative/python/megengine/functional/math.py
  4. +5
    -30
      imperative/python/megengine/functional/nn.py
  5. +14
    -16
      imperative/python/megengine/functional/tensor.py
  6. +9
    -4
      imperative/python/megengine/functional/vision.py
  7. +5
    -6
      imperative/python/test/integration/test_bn.py
  8. +2
    -3
      imperative/python/test/integration/test_converge.py
  9. +2
    -3
      imperative/python/test/integration/test_converge_with_gradient_clip.py
  10. +2
    -3
      imperative/python/test/integration/test_converge_with_swap_and_drop.py
  11. +1
    -1
      imperative/python/test/unit/core/test_interpreter.py
  12. +12
    -1
      imperative/python/test/unit/functional/test_elemwise.py
  13. +12
    -12
      imperative/python/test/unit/functional/test_functional.py
  14. +1
    -4
      imperative/python/test/unit/optimizer/test_clip_grad.py

+ 6
- 4
imperative/python/megengine/core/tensor/array_method.py View File

@@ -13,7 +13,7 @@ from typing import Union
import numpy as np import numpy as np


from .._imperative_rt.common import CompNode from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from ..ops import builtin from ..ops import builtin
from . import amp from . import amp
from .indexing import getitem, setitem from .indexing import getitem, setitem
@@ -81,7 +81,11 @@ def _matmul(inp1, inp2):
inp1, inp2 = cast_tensors(inp1, inp2) inp1, inp2 = cast_tensors(inp1, inp2)
else: else:
compute_mode = "default" compute_mode = "default"
inp1, inp2 = convert_inputs(inp1, inp2)
dtype = dtype_promotion(inp1, inp2)
if inp1.dtype != dtype:
inp1 = inp1.astype(dtype)
if inp2.dtype != dtype:
inp2 = inp2.astype(dtype)
op = builtin.MatrixMul( op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default" transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
) )
@@ -91,7 +95,6 @@ def _matmul(inp1, inp2):


def _transpose(data, axes): def _transpose(data, axes):
op = builtin.Dimshuffle(axes) op = builtin.Dimshuffle(axes)
(data,) = convert_inputs(data)
(result,) = apply(op, data) (result,) = apply(op, data)
return result return result


@@ -201,7 +204,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
def _reduce(mode): def _reduce(mode):
def f(self, axis=None, keepdims: bool = False): def f(self, axis=None, keepdims: bool = False):
data = self data = self
(data,) = convert_inputs(data)
if mode == "mean": if mode == "mean":
data = data.astype("float32") data = data.astype("float32")
elif self.dtype == np.bool_: elif self.dtype == np.bool_:


+ 1
- 1
imperative/python/megengine/functional/elemwise.py View File

@@ -13,7 +13,7 @@ from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import Elemwise from ..core.ops.builtin import Elemwise
from ..core.tensor.array_method import _elwise from ..core.tensor.array_method import _elwise
from ..core.tensor.utils import astype, convert_inputs
from ..core.tensor.utils import convert_inputs
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_func from ..utils.deprecation import deprecated_func




+ 9
- 7
imperative/python/megengine/functional/math.py View File

@@ -10,16 +10,16 @@ import collections
import math import math
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union


from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._trace_option import use_symbolic_shape from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import amp from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar
from ..tensor import Tensor from ..tensor import Tensor
from .debug_param import get_execution_strategy from .debug_param import get_execution_strategy
from .elemwise import clip, exp, log, log1p
from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze
from .elemwise import clip
from .tensor import broadcast_to, concat, expand_dims, squeeze


__all__ = [ __all__ = [
"argmax", "argmax",
@@ -816,10 +816,13 @@ def matmul(
compute_mode = "float32" compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2) inp1, inp2 = cast_tensors(inp1, inp2)
else: else:
inp1, inp2 = convert_inputs(inp1, inp2)
dtype = dtype_promotion(inp1, inp2)
if inp1.dtype != dtype:
inp1 = inp1.astype(dtype)
if inp2.dtype != dtype:
inp2 = inp2.astype(dtype)


remove_row, remove_col = False, False remove_row, remove_col = False, False

dim1, dim2 = inp1.ndim, inp2.ndim dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication # handle dim=1 cases, dot and matrix-vector multiplication
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
@@ -931,7 +934,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:


""" """
op = builtin.Dot() op = builtin.Dot()
inp1, inp2 = convert_inputs(inp1, inp2)
assert ( assert (
inp1.ndim <= 1 and inp2.ndim <= 1 inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar" ), "Input tensors for dot must be 1-dimensional or scalar"


+ 5
- 30
imperative/python/megengine/functional/nn.py View File

@@ -10,8 +10,6 @@
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union


from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.graph import VarNode
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const from ..core.ops.special import Const
@@ -21,7 +19,6 @@ from ..core.tensor.utils import (
astensor1d, astensor1d,
astype, astype,
cast_tensors, cast_tensors,
convert_inputs,
convert_single_value, convert_single_value,
setscalar, setscalar,
) )
@@ -33,18 +30,9 @@ from ..utils.deprecation import deprecated_func
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from .debug_param import get_execution_strategy from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum from .distributed import all_reduce_sum
from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum
from .math import argsort, matmul, max, prod, sum
from .tensor import (
broadcast_to,
concat,
expand_dims,
full,
ones,
reshape,
squeeze,
zeros,
)
from .elemwise import _elwise, exp, log, log1p, maximum, minimum
from .math import matmul, max, sum
from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros


__all__ = [ __all__ = [
"adaptive_avg_pool2d", "adaptive_avg_pool2d",
@@ -167,8 +155,6 @@ def conv1d(
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias) inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)


inp = expand_dims(inp, 3) inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3) weight = expand_dims(weight, 3)
@@ -246,8 +232,6 @@ def conv2d(
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias) inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)


stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
@@ -304,7 +288,6 @@ def conv3d(
:return: output tensor. :return: output tensor.
""" """
assert conv_mode.lower() == "cross_correlation" assert conv_mode.lower() == "cross_correlation"
inp, weight = convert_inputs(inp, weight)


D, H, W = 0, 1, 2 D, H, W = 0, 1, 2


@@ -379,8 +362,6 @@ def conv_transpose2d(
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias) inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)


if groups != 1: if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.") raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -454,7 +435,8 @@ def deformable_conv2d(
compute_mode = "float32" compute_mode = "float32"
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias) inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
else: else:
inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)
offset = offset.astype("float32")
mask = mask.astype("float32")


stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
@@ -493,7 +475,6 @@ def local_conv2d(
conv_mode.lower() == "cross_correlation" conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
) )
inp, weight = convert_inputs(inp, weight)


stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
@@ -539,8 +520,6 @@ def conv_transpose3d(
:param dilation: dilation of the 3D convolution operation. Default: 1 :param dilation: dilation of the 3D convolution operation. Default: 1
:return: output tensor. :return: output tensor.
""" """
inp, weight = convert_inputs(inp, weight)

D, H, W = 0, 1, 2 D, H, W = 0, 1, 2
pad = _triple(padding) pad = _triple(padding)
stride = _triple_nonzero(stride) stride = _triple_nonzero(stride)
@@ -1078,10 +1057,6 @@ def batch_norm(
weight, bias, running_mean, running_var = cast_tensors( weight, bias, running_mean, running_var = cast_tensors(
weight, bias, running_mean, running_var, promote=True weight, bias, running_mean, running_var, promote=True
) )
elif compute_mode != "float32":
inp, weight, bias, running_mean, running_var = convert_inputs(
inp, weight, bias, running_mean, running_var
)
weight = make_full_if_none(weight, 1) weight = make_full_if_none(weight, 1)
bias = make_full_if_none(bias, 0) bias = make_full_if_none(bias, 0)




+ 14
- 16
imperative/python/megengine/functional/tensor.py View File

@@ -6,25 +6,18 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Iterable, Optional, Sequence, Union from typing import Iterable, Optional, Sequence, Union


import numpy as np import numpy as np


from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
from ..core._wrap import as_device from ..core._wrap import as_device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import (
astensor1d,
convert_inputs,
convert_single_value,
dtype_promotion,
get_device,
)
from ..core.tensor.utils import astensor1d, convert_inputs, get_device
from ..device import get_default_device from ..device import get_default_device
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import ceil, floor_div from .elemwise import ceil, floor_div
@@ -288,6 +281,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
if len(inps) == 1: if len(inps) == 1:
return inps[0] return inps[0]


# FIXME: remove this convert_inputs
inps = convert_inputs(*inps, device=device) inps = convert_inputs(*inps, device=device)
if device is None: if device is None:
device = get_device(inps) device = get_device(inps)
@@ -640,6 +634,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:


.. testcode:: .. testcode::


import numpy as np
from megengine import tensor from megengine import tensor
import megengine.functional as F import megengine.functional as F
mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool)) mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool))
@@ -657,7 +652,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
[7. 4.]] [7. 4.]]
""" """


x, y = convert_inputs(x, y)
if not isinstance(x, Tensor): if not isinstance(x, Tensor):
raise TypeError("input x must be a tensor") raise TypeError("input x must be a tensor")
if not isinstance(y, Tensor): if not isinstance(y, Tensor):
@@ -669,6 +663,12 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
if x.device != mask.device: if x.device != mask.device:
raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device)) raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))


dtype = dtype_promotion(x, y)
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
y = y.astype(dtype)

v0, index0 = cond_take(mask, x) v0, index0 = cond_take(mask, x)
v1, index1 = cond_take(~mask, y) v1, index1 = cond_take(~mask, y)


@@ -1021,12 +1021,10 @@ def arange(
if stop is None: if stop is None:
start, stop = 0, start start, stop = 0, start


if isinstance(start, Tensor):
start = start.astype("float32")
if isinstance(stop, Tensor):
stop = stop.astype("float32")
if isinstance(step, Tensor):
step = step.astype("float32")
start = Tensor(start, dtype="float32")
stop = Tensor(stop, dtype="float32")
step = Tensor(step, dtype="float32")

num = ceil((stop - start) / step) num = ceil((stop - start) / step)
stop = start + step * (num - 1) stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device) result = linspace(start, stop, num, device=device)


+ 9
- 4
imperative/python/megengine/functional/vision.py View File

@@ -8,6 +8,8 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Optional, Tuple, Union from typing import Iterable, Optional, Tuple, Union


import numpy as np

from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils from ..core.tensor import megbrain_graph, utils
@@ -98,7 +100,6 @@ def roi_pooling(
output_shape = (output_shape, output_shape) output_shape = (output_shape, output_shape)


op = builtin.ROIPooling(mode=mode, scale=scale) op = builtin.ROIPooling(mode=mode, scale=scale)
inp, rois = utils.convert_inputs(inp, rois)
result, _ = apply( result, _ = apply(
op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device) op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
) )
@@ -187,6 +188,8 @@ def roi_align(
[0.1359 0.1359]]] [0.1359 0.1359]]]


""" """
if inp.dtype != np.float32:
inp = inp.astype(np.float32)
mode = mode.lower() mode = mode.lower()
assert mode in ["max", "average"], "only max/average mode is supported" assert mode in ["max", "average"], "only max/average mode is supported"
if isinstance(output_shape, int): if isinstance(output_shape, int):
@@ -207,7 +210,6 @@ def roi_align(
sample_height=sample_height, sample_height=sample_height,
sample_width=sample_width, sample_width=sample_width,
) )
inp, rois = utils.convert_inputs(inp, rois)
result, *_ = apply(op, inp, rois) result, *_ = apply(op, inp, rois)
return result return result


@@ -270,7 +272,7 @@ def nms(
max_output = boxes.shape[0] max_output = boxes.shape[0]


op = builtin.NMSKeep(iou_thresh, max_output) op = builtin.NMSKeep(iou_thresh, max_output)
inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
inp = (boxes.reshape(1, -1, 4),)
indices, count = apply(op, *inp) indices, count = apply(op, *inp)
indices = indices[0][: count[0]] indices = indices[0][: count[0]]
keep_inds = sorted_idx[indices] keep_inds = sorted_idx[indices]
@@ -442,10 +444,13 @@ def warp_perspective(
[ 9. 10.]]]] [ 9. 10.]]]]


""" """
if inp.dtype == np.float32:
mat = mat.astype("float32")
if inp.dtype == np.float16:
inp = inp.astype("float32")
op = builtin.WarpPerspective( op = builtin.WarpPerspective(
imode=interp_mode, bmode=border_mode, format=format, border_val=border_val imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
) )
inp, mat = utils.convert_inputs(inp, mat)
out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device) out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
if mat_idx is not None: if mat_idx is not None:
mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device) mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device)


+ 5
- 6
imperative/python/test/integration/test_bn.py View File

@@ -14,8 +14,7 @@ import megengine.autodiff as ad
import megengine.distributed as dist import megengine.distributed as dist
import megengine.functional as F import megengine.functional as F
import megengine.optimizer as optimizer import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine import tensor
from megengine.jit import trace from megengine.jit import trace
from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm


@@ -88,7 +87,7 @@ def test_bn_no_track_stat():
optim = optimizer.SGD(m.parameters(), lr=1.0) optim = optimizer.SGD(m.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()


data = np.random.random((6, nchannel, 2, 2)).astype("float32")
data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
with gm: with gm:
loss = m(data).sum() loss = m(data).sum()
gm.backward(loss) gm.backward(loss)
@@ -110,7 +109,7 @@ def test_bn_no_track_stat2():
optim = optimizer.SGD(m.parameters(), lr=1.0) optim = optimizer.SGD(m.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()


data = np.random.random((6, nchannel, 2, 2)).astype("float32")
data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
with gm: with gm:
loss = m(data).sum() loss = m(data).sum()
gm.backward(loss) gm.backward(loss)
@@ -146,7 +145,7 @@ def test_trace_bn_forward_twice():
pred = net(inp) pred = net(inp)
return pred return pred


x = np.ones((1, 1, 32, 32), dtype=np.float32)
x = tensor(np.ones((1, 1, 32, 32), dtype=np.float32))
y = train_bn(x, net=Simple()) y = train_bn(x, net=Simple())
np.testing.assert_equal(y.numpy(), 0) np.testing.assert_equal(y.numpy(), 0)


@@ -194,5 +193,5 @@ def test_trace_several_syncbn(trace_mode):
def test_frozen_bn_no_affine(): def test_frozen_bn_no_affine():
nchannel = 3 nchannel = 3
m = BatchNorm2d(nchannel, freeze=True, affine=False) m = BatchNorm2d(nchannel, freeze=True, affine=False)
data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
m(data).numpy() m(data).numpy()

+ 2
- 3
imperative/python/test/integration/test_converge.py View File

@@ -9,7 +9,6 @@
import itertools import itertools


import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
@@ -105,10 +104,10 @@ def test_training_converge():
xx, yy = np.meshgrid(x, x) xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1)) xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1)) yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))


pred = infer(data).numpy() pred = infer(data).numpy()
precision = calculate_precision(data, pred)
precision = calculate_precision(data.numpy(), pred)
assert precision == 1.0, "Test precision must be high enough, get {}".format( assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision precision
) )

+ 2
- 3
imperative/python/test/integration/test_converge_with_gradient_clip.py View File

@@ -9,7 +9,6 @@
import itertools import itertools


import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
@@ -110,10 +109,10 @@ def test_training_converge():
xx, yy = np.meshgrid(x, x) xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1)) xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1)) yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))


pred = infer(data).numpy() pred = infer(data).numpy()
precision = calculate_precision(data, pred)
precision = calculate_precision(data.numpy(), pred)
print("precision=", precision) print("precision=", precision)
assert precision == 1.0, "Test precision must be high enough, get {}".format( assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision precision


+ 2
- 3
imperative/python/test/integration/test_converge_with_swap_and_drop.py View File

@@ -9,7 +9,6 @@
import itertools import itertools


import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
@@ -118,10 +117,10 @@ def test_training_converge_with_swap_and_drop():
xx, yy = np.meshgrid(x, x) xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1)) xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1)) yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))


pred = infer(Tensor(data)).numpy() pred = infer(Tensor(data)).numpy()
precision = calculate_precision(data, pred)
precision = calculate_precision(data.numpy(), pred)
assert precision == 1.0, "Test precision must be high enough, get {}".format( assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision precision
) )


+ 1
- 1
imperative/python/test/unit/core/test_interpreter.py View File

@@ -36,7 +36,7 @@ def test_level1_infer_value():
def test_level1_infer_shape_with_unknown(): def test_level1_infer_shape_with_unknown():
config_async_level(2) config_async_level(2)
a = mge.tensor([[1, 2, 2, 3]], dtype="float32") a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
b = mge.tensor([1, 1])
b = mge.tensor([1, 1], dtype="float32")
multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32") multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32")
c = F.matmul(b, multi2) c = F.matmul(b, multi2)
# make DepType::SHAPE unknown # make DepType::SHAPE unknown


+ 12
- 1
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -13,7 +13,7 @@ import megengine.functional as F
import megengine.functional.elemwise as elemwise import megengine.functional.elemwise as elemwise
from megengine import tensor from megengine import tensor
from megengine.core.tensor import dtype from megengine.core.tensor import dtype
from megengine.functional.elemwise import Elemwise, _elwise
from megengine.functional.elemwise import Elemwise
from megengine.jit import trace from megengine.jit import trace




@@ -57,6 +57,17 @@ def test_multiply():
) )




def test_div():
np.testing.assert_allclose(
F.div(tensor([3, 4]), 2).numpy(),
np.divide(np.array([3, 4], dtype=np.float32), 2),
)

np.testing.assert_allclose(
(tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2),
)


def test_clamp(): def test_clamp():
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
`F.clip` will fall into wrong conditions unexpectedly. `F.clip` will fall into wrong conditions unexpectedly.


+ 12
- 12
imperative/python/test/unit/functional/test_functional.py View File

@@ -456,9 +456,10 @@ def test_interpolate_fastpath():
np.testing.assert_equal(out.item(), np_x.mean()) np.testing.assert_equal(out.item(), np_x.mean())




def test_warp_perspective():
@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective(dt):
inp_shape = (1, 1, 4, 4) inp_shape = (1, 1, 4, 4)
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
x = tensor(np.arange(16, dtype=dt).reshape(inp_shape))
M_shape = (1, 3, 3) M_shape = (1, 3, 3)
# M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
M = tensor( M = tensor(
@@ -467,14 +468,13 @@ def test_warp_perspective():
).reshape(M_shape) ).reshape(M_shape)
) )
outp = F.vision.warp_perspective(x, M, (2, 2)) outp = F.vision.warp_perspective(x, M, (2, 2))
np.testing.assert_equal(
outp.numpy(), np.array([[[[5.0, 6.0], [9.0, 10.0]]]], dtype=np.float32)
)
np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))




def test_warp_perspective_mat_idx():
@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective_mat_idx(dt):
inp_shape = (2, 1, 4, 4) inp_shape = (2, 1, 4, 4)
x = tensor(np.arange(32, dtype=np.float32).reshape(inp_shape))
x = tensor(np.arange(32, dtype=dt).reshape(inp_shape))
M_shape = (1, 3, 3) M_shape = (1, 3, 3)
# M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
M = tensor( M = tensor(
@@ -488,12 +488,12 @@ def test_warp_perspective_mat_idx():
outp.numpy(), outp.numpy(),
np.array( np.array(
[ [
[[[5.0, 6.0], [9.0, 10.0]]],
[[[21.0, 22.0], [25.0, 26.0]]],
[[[21.0, 22.0], [25.0, 26.0]]],
[[[5.0, 6.0], [9.0, 10.0]]],
[[[5, 6], [9, 10]]],
[[[21, 22], [25, 26]]],
[[[21, 22], [25, 26]]],
[[[5, 6], [9, 10]]],
], ],
dtype=np.float32,
dtype=dt,
), ),
) )




+ 1
- 4
imperative/python/test/unit/optimizer/test_clip_grad.py View File

@@ -5,11 +5,8 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform
import weakref


import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
@@ -65,7 +62,7 @@ def test_clip_grad_value():
gm = ad.GradManager().attach(net.parameters()) gm = ad.GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9) opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
with gm: with gm:
y = net(x)
y = net(mge.tensor(x))
y = y.mean() y = y.mean()
gm.backward(y) gm.backward(y)
save_grad_value(net) save_grad_value(net)


Loading…
Cancel
Save