
fix(imperative): remove convert_inputs

Megvii Engine Team (huangxinda), 3 years ago
Branch: release-1.5
Commit: 070c811732
GitOrigin-RevId: a3c43db746
14 changed files with 81 additions and 95 deletions:

1. +6 -4    imperative/python/megengine/core/tensor/array_method.py
2. +1 -1    imperative/python/megengine/functional/elemwise.py
3. +9 -7    imperative/python/megengine/functional/math.py
4. +5 -30   imperative/python/megengine/functional/nn.py
5. +14 -16  imperative/python/megengine/functional/tensor.py
6. +9 -4    imperative/python/megengine/functional/vision.py
7. +5 -6    imperative/python/test/integration/test_bn.py
8. +2 -3    imperative/python/test/integration/test_converge.py
9. +2 -3    imperative/python/test/integration/test_converge_with_gradient_clip.py
10. +2 -3   imperative/python/test/integration/test_converge_with_swap_and_drop.py
11. +1 -1   imperative/python/test/unit/core/test_interpreter.py
12. +12 -1  imperative/python/test/unit/functional/test_elemwise.py
13. +12 -12 imperative/python/test/unit/functional/test_functional.py
14. +1 -4   imperative/python/test/unit/optimizer/test_clip_grad.py

+6 -4  imperative/python/megengine/core/tensor/array_method.py

@@ -13,7 +13,7 @@ from typing import Union
 import numpy as np
 
 from .._imperative_rt.common import CompNode
-from .._imperative_rt.core2 import SymbolVar, Tensor, apply
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
@@ -81,7 +81,11 @@ def _matmul(inp1, inp2):
         inp1, inp2 = cast_tensors(inp1, inp2)
     else:
         compute_mode = "default"
-        inp1, inp2 = convert_inputs(inp1, inp2)
+        dtype = dtype_promotion(inp1, inp2)
+        if inp1.dtype != dtype:
+            inp1 = inp1.astype(dtype)
+        if inp2.dtype != dtype:
+            inp2 = inp2.astype(dtype)
     op = builtin.MatrixMul(
         transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
@@ -91,7 +95,6 @@ def _matmul(inp1, inp2):
 
 def _transpose(data, axes):
     op = builtin.Dimshuffle(axes)
-    (data,) = convert_inputs(data)
     (result,) = apply(op, data)
     return result
 
@@ -201,7 +204,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        (data,) = convert_inputs(data)
         if mode == "mean":
             data = data.astype("float32")
         elif self.dtype == np.bool_:
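The pattern above recurs throughout this commit: rather than funneling operands through `convert_inputs`, compute the promoted dtype once with `dtype_promotion` and cast only the operand that differs. A minimal self-contained sketch of that pattern, using `numpy.result_type` as a stand-in for the C++-side `dtype_promotion` (whose rules may differ in edge cases):

```python
import numpy as np

def promote_pair(a, b):
    # Stand-in for: dtype = dtype_promotion(inp1, inp2); cast only mismatches.
    dtype = np.result_type(a, b)
    if a.dtype != dtype:
        a = a.astype(dtype)
    if b.dtype != dtype:
        b = b.astype(dtype)
    return a, b

x, y = promote_pair(np.arange(4, dtype=np.int16), np.ones(4, dtype=np.float32))
assert x.dtype == y.dtype == np.float32  # int16 + float32 promotes to float32
```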


+1 -1  imperative/python/megengine/functional/elemwise.py

@@ -13,7 +13,7 @@ from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor.array_method import _elwise
-from ..core.tensor.utils import astype, convert_inputs
+from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
 


+9 -7  imperative/python/megengine/functional/math.py

@@ -10,16 +10,16 @@ import collections
 import math
 from typing import Optional, Sequence, Tuple, Union
 
-from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.special import Const
 from ..core.tensor import amp
-from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
+from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
-from .elemwise import clip, exp, log, log1p
-from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze
+from .elemwise import clip
+from .tensor import broadcast_to, concat, expand_dims, squeeze
 
 __all__ = [
     "argmax",
@@ -816,10 +816,13 @@ def matmul(
         compute_mode = "float32"
         inp1, inp2 = cast_tensors(inp1, inp2)
     else:
-        inp1, inp2 = convert_inputs(inp1, inp2)
+        dtype = dtype_promotion(inp1, inp2)
+        if inp1.dtype != dtype:
+            inp1 = inp1.astype(dtype)
+        if inp2.dtype != dtype:
+            inp2 = inp2.astype(dtype)
 
     remove_row, remove_col = False, False
-
     dim1, dim2 = inp1.ndim, inp2.ndim
     # handle dim=1 cases, dot and matrix-vector multiplication
     if dim1 == 1 and dim2 == 1:
@@ -931,7 +934,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
 
     """
     op = builtin.Dot()
-    inp1, inp2 = convert_inputs(inp1, inp2)
    assert (
         inp1.ndim <= 1 and inp2.ndim <= 1
     ), "Input tensors for dot must be 1-dimensional or scalar"


+5 -30  imperative/python/megengine/functional/nn.py

@@ -10,8 +10,6 @@
 from typing import Optional, Sequence, Tuple, Union
 
 from ..core._imperative_rt.core2 import apply
-from ..core._imperative_rt.graph import VarNode
-from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.builtin import BatchNorm, Elemwise
 from ..core.ops.special import Const
@@ -21,7 +19,6 @@ from ..core.tensor.utils import (
     astensor1d,
     astype,
     cast_tensors,
-    convert_inputs,
     convert_single_value,
     setscalar,
 )
@@ -33,18 +30,9 @@ from ..utils.deprecation import deprecated_func
 from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
-from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum
-from .math import argsort, matmul, max, prod, sum
-from .tensor import (
-    broadcast_to,
-    concat,
-    expand_dims,
-    full,
-    ones,
-    reshape,
-    squeeze,
-    zeros,
-)
+from .elemwise import _elwise, exp, log, log1p, maximum, minimum
+from .math import matmul, max, sum
+from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
 
 __all__ = [
     "adaptive_avg_pool2d",
@@ -167,8 +155,6 @@ def conv1d(
     if amp._enabled:
         compute_mode = "float32"
         inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        inp, weight = convert_inputs(inp, weight)
 
     inp = expand_dims(inp, 3)
     weight = expand_dims(weight, 3)
@@ -246,8 +232,6 @@ def conv2d(
     if amp._enabled:
         compute_mode = "float32"
         inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        inp, weight = convert_inputs(inp, weight)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -304,7 +288,6 @@ def conv3d(
     :return: output tensor.
     """
     assert conv_mode.lower() == "cross_correlation"
-    inp, weight = convert_inputs(inp, weight)
 
     D, H, W = 0, 1, 2
 
@@ -379,8 +362,6 @@ def conv_transpose2d(
     if amp._enabled:
         compute_mode = "float32"
         inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        inp, weight = convert_inputs(inp, weight)
 
     if groups != 1:
         raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -454,7 +435,8 @@ def deformable_conv2d(
         compute_mode = "float32"
         inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
     else:
-        inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)
+        offset = offset.astype("float32")
+        mask = mask.astype("float32")
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -493,7 +475,6 @@ def local_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    inp, weight = convert_inputs(inp, weight)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -539,8 +520,6 @@ def conv_transpose3d(
     :param dilation: dilation of the 3D convolution operation. Default: 1
     :return: output tensor.
     """
-    inp, weight = convert_inputs(inp, weight)
-
     D, H, W = 0, 1, 2
     pad = _triple(padding)
     stride = _triple_nonzero(stride)
@@ -1078,10 +1057,6 @@ def batch_norm(
         weight, bias, running_mean, running_var = cast_tensors(
             weight, bias, running_mean, running_var, promote=True
         )
-    elif compute_mode != "float32":
-        inp, weight, bias, running_mean, running_var = convert_inputs(
-            inp, weight, bias, running_mean, running_var
-        )
     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)
 
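One consequence threaded through these conv and batch-norm paths: without `convert_inputs`, the entry points no longer wrap raw numpy arrays on the caller's behalf (the test updates below are the mirror image of this). A hedged sketch of the resulting calling convention, with arbitrary shapes chosen for illustration:

```python
import numpy as np
import megengine.functional as F
from megengine import Parameter, tensor

x_np = np.random.random((2, 3, 8, 8)).astype("float32")
w = Parameter(np.random.random((4, 3, 3, 3)).astype("float32"))

y = F.conv2d(tensor(x_np), w)  # wrap ndarrays yourself; conv2d no longer does it
print(y.shape)                 # (2, 4, 6, 6) for a 3x3 kernel, stride 1, no padding
```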



+14 -16  imperative/python/megengine/functional/tensor.py

@@ -6,25 +6,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import math
 from typing import Iterable, Optional, Sequence, Union
 
 import numpy as np
 
 from ..core._imperative_rt import CompNode
-from ..core._imperative_rt.core2 import SymbolVar, apply
+from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
 from ..core.tensor.array_method import _broadcast, _remove_axis
-from ..core.tensor.utils import (
-    astensor1d,
-    convert_inputs,
-    convert_single_value,
-    dtype_promotion,
-    get_device,
-)
+from ..core.tensor.utils import astensor1d, convert_inputs, get_device
 from ..device import get_default_device
 from ..tensor import Tensor
 from .elemwise import ceil, floor_div
@@ -288,6 +281,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
     if len(inps) == 1:
         return inps[0]
 
+    # FIXME: remove this convert_inputs
     inps = convert_inputs(*inps, device=device)
     if device is None:
         device = get_device(inps)
@@ -640,6 +634,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
 
     .. testcode::
 
+        import numpy as np
         from megengine import tensor
         import megengine.functional as F
         mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool))
@@ -657,7 +652,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
          [7. 4.]]
     """
 
-    x, y = convert_inputs(x, y)
     if not isinstance(x, Tensor):
         raise TypeError("input x must be a tensor")
     if not isinstance(y, Tensor):
@@ -669,6 +663,12 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
     if x.device != mask.device:
         raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
 
+    dtype = dtype_promotion(x, y)
+    if x.dtype != dtype:
+        x = x.astype(dtype)
+    if y.dtype != dtype:
+        y = y.astype(dtype)
+
     v0, index0 = cond_take(mask, x)
     v1, index1 = cond_take(~mask, y)
 
@@ -1021,12 +1021,10 @@ def arange(
     if stop is None:
         start, stop = 0, start
 
-    if isinstance(start, Tensor):
-        start = start.astype("float32")
-    if isinstance(stop, Tensor):
-        stop = stop.astype("float32")
-    if isinstance(step, Tensor):
-        step = step.astype("float32")
+    start = Tensor(start, dtype="float32")
+    stop = Tensor(stop, dtype="float32")
+    step = Tensor(step, dtype="float32")
 
     num = ceil((stop - start) / step)
     stop = start + step * (num - 1)
     result = linspace(start, stop, num, device=device)
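`where` now rejects non-Tensor arguments outright and aligns `x` and `y` itself via `dtype_promotion`. A usage sketch reusing the docstring's data, with `x` deliberately int32; the float32 output dtype is the expected promotion, not verified against this revision:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
x = tensor(np.array([[1, 2], [3, 4]], dtype=np.int32))    # deliberately int32
y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))

out = F.where(mask, x, y)  # x is cast to the promoted dtype before cond_take
print(out.numpy())         # [[1. 6.]
                           #  [7. 4.]]
```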


+9 -4  imperative/python/megengine/functional/vision.py

@@ -8,6 +8,8 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable, Optional, Tuple, Union
 
+import numpy as np
+
 from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..core.tensor import megbrain_graph, utils
@@ -98,7 +100,6 @@ def roi_pooling(
         output_shape = (output_shape, output_shape)
 
     op = builtin.ROIPooling(mode=mode, scale=scale)
-    inp, rois = utils.convert_inputs(inp, rois)
     result, _ = apply(
         op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
     )
@@ -187,6 +188,8 @@ def roi_align(
           [0.1359 0.1359]]]
 
     """
+    if inp.dtype != np.float32:
+        inp = inp.astype(np.float32)
     mode = mode.lower()
     assert mode in ["max", "average"], "only max/average mode is supported"
     if isinstance(output_shape, int):
@@ -207,7 +210,6 @@ def roi_align(
         sample_height=sample_height,
         sample_width=sample_width,
     )
-    inp, rois = utils.convert_inputs(inp, rois)
     result, *_ = apply(op, inp, rois)
     return result
 
@@ -270,7 +272,7 @@ def nms(
         max_output = boxes.shape[0]
 
     op = builtin.NMSKeep(iou_thresh, max_output)
-    inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
+    inp = (boxes.reshape(1, -1, 4),)
     indices, count = apply(op, *inp)
     indices = indices[0][: count[0]]
     keep_inds = sorted_idx[indices]
@@ -442,10 +444,13 @@ def warp_perspective(
            [ 9. 10.]]]]
 
     """
+    if inp.dtype == np.float32:
+        mat = mat.astype("float32")
+    if inp.dtype == np.float16:
+        inp = inp.astype("float32")
     op = builtin.WarpPerspective(
         imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
     )
-    inp, mat = utils.convert_inputs(inp, mat)
     out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
     if mat_idx is not None:
         mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device)
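The vision ops now own their dtype handling: `roi_align` upcasts non-float32 input up front, and `warp_perspective` casts `mat` (or a float16 image) to float32 instead of relying on `convert_inputs`. A sketch mirroring the uint8 case of the parametrized test added further below, assuming this revision's behavior:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(16, dtype=np.uint8).reshape(1, 1, 4, 4))
# M translates by (+1, +1): dst(h, w) = src(h + 1, w + 1)
M = tensor(
    np.array(
        [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
    ).reshape(1, 3, 3)
)
out = F.vision.warp_perspective(x, M, (2, 2))
print(out.numpy())  # expected: [[[[5 6] [9 10]]]], dtype uint8
```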


+5 -6  imperative/python/test/integration/test_bn.py

@@ -14,8 +14,7 @@ import megengine.autodiff as ad
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.optimizer as optimizer
-from megengine import Parameter, tensor
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine import tensor
 from megengine.jit import trace
 from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
 
@@ -88,7 +87,7 @@ def test_bn_no_track_stat():
     optim = optimizer.SGD(m.parameters(), lr=1.0)
     optim.clear_grad()
 
-    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
+    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     with gm:
         loss = m(data).sum()
         gm.backward(loss)
@@ -110,7 +109,7 @@ def test_bn_no_track_stat2():
     optim = optimizer.SGD(m.parameters(), lr=1.0)
     optim.clear_grad()
 
-    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
+    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     with gm:
         loss = m(data).sum()
         gm.backward(loss)
@@ -146,7 +145,7 @@ def test_trace_bn_forward_twice():
         pred = net(inp)
         return pred
 
-    x = np.ones((1, 1, 32, 32), dtype=np.float32)
+    x = tensor(np.ones((1, 1, 32, 32), dtype=np.float32))
     y = train_bn(x, net=Simple())
     np.testing.assert_equal(y.numpy(), 0)
 
@@ -194,5 +193,5 @@ def test_trace_several_syncbn(trace_mode):
 def test_frozen_bn_no_affine():
     nchannel = 3
     m = BatchNorm2d(nchannel, freeze=True, affine=False)
-    data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     m(data).numpy()

+2 -3  imperative/python/test/integration/test_converge.py

@@ -9,7 +9,6 @@
 import itertools
 
 import numpy as np
-import pytest
 
 import megengine as mge
 import megengine.autodiff as ad
@@ -105,10 +104,10 @@ def test_training_converge():
     xx, yy = np.meshgrid(x, x)
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
-    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
+    data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
 
     pred = infer(data).numpy()
-    precision = calculate_precision(data, pred)
+    precision = calculate_precision(data.numpy(), pred)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )

+2 -3  imperative/python/test/integration/test_converge_with_gradient_clip.py

@@ -9,7 +9,6 @@
 import itertools
 
 import numpy as np
-import pytest
 
 import megengine as mge
 import megengine.autodiff as ad
@@ -110,10 +109,10 @@ def test_training_converge():
     xx, yy = np.meshgrid(x, x)
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
-    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
+    data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
 
     pred = infer(data).numpy()
-    precision = calculate_precision(data, pred)
+    precision = calculate_precision(data.numpy(), pred)
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision


+2 -3  imperative/python/test/integration/test_converge_with_swap_and_drop.py

@@ -9,7 +9,6 @@
 import itertools
 
 import numpy as np
-import pytest
 
 import megengine as mge
 import megengine.autodiff as ad
@@ -118,10 +117,10 @@ def test_training_converge_with_swap_and_drop():
     xx, yy = np.meshgrid(x, x)
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
-    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
+    data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
 
     pred = infer(Tensor(data)).numpy()
-    precision = calculate_precision(data, pred)
+    precision = calculate_precision(data.numpy(), pred)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )


+1 -1  imperative/python/test/unit/core/test_interpreter.py

@@ -36,7 +36,7 @@ def test_level1_infer_value():
 def test_level1_infer_shape_with_unknown():
     config_async_level(2)
     a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
-    b = mge.tensor([1, 1])
+    b = mge.tensor([1, 1], dtype="float32")
     multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32")
     c = F.matmul(b, multi2)
     # make DepType::SHAPE unknown


+12 -1  imperative/python/test/unit/functional/test_elemwise.py

@@ -13,7 +13,7 @@ import megengine.functional as F
 import megengine.functional.elemwise as elemwise
 from megengine import tensor
 from megengine.core.tensor import dtype
-from megengine.functional.elemwise import Elemwise, _elwise
+from megengine.functional.elemwise import Elemwise
 from megengine.jit import trace
 
 
@@ -57,6 +57,17 @@ def test_multiply():
     )
 
 
+def test_div():
+    np.testing.assert_allclose(
+        F.div(tensor([3, 4]), 2).numpy(),
+        np.divide(np.array([3, 4], dtype=np.float32), 2),
+    )
+
+    np.testing.assert_allclose(
+        (tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2),
+    )
+
+
 def test_clamp():
     """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
     `F.clip` will fall into wrong conditions unexpectedly.
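The new `test_div` pins down division semantics once `convert_inputs` is gone: an int tensor divided by a Python scalar still comes back as float via true division. An equivalent standalone check, assuming `tensor([3, 4])` defaults to int32:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

expected = np.array([1.5, 2.0], dtype=np.float32)
np.testing.assert_allclose(F.div(tensor([3, 4]), 2).numpy(), expected)
np.testing.assert_allclose((tensor([3, 4]) / 2).numpy(), expected)
```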


+12 -12  imperative/python/test/unit/functional/test_functional.py

@@ -456,9 +456,10 @@ def test_interpolate_fastpath():
     np.testing.assert_equal(out.item(), np_x.mean())
 
 
-def test_warp_perspective():
+@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
+def test_warp_perspective(dt):
     inp_shape = (1, 1, 4, 4)
-    x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
+    x = tensor(np.arange(16, dtype=dt).reshape(inp_shape))
     M_shape = (1, 3, 3)
     # M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1)
     M = tensor(
@@ -467,14 +468,13 @@ def test_warp_perspective():
         ).reshape(M_shape)
     )
     outp = F.vision.warp_perspective(x, M, (2, 2))
-    np.testing.assert_equal(
-        outp.numpy(), np.array([[[[5.0, 6.0], [9.0, 10.0]]]], dtype=np.float32)
-    )
+    np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))
 
 
-def test_warp_perspective_mat_idx():
+@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
+def test_warp_perspective_mat_idx(dt):
     inp_shape = (2, 1, 4, 4)
-    x = tensor(np.arange(32, dtype=np.float32).reshape(inp_shape))
+    x = tensor(np.arange(32, dtype=dt).reshape(inp_shape))
     M_shape = (1, 3, 3)
     # M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1)
     M = tensor(
@@ -488,12 +488,12 @@ def test_warp_perspective_mat_idx():
         outp.numpy(),
         np.array(
             [
-                [[[5.0, 6.0], [9.0, 10.0]]],
-                [[[21.0, 22.0], [25.0, 26.0]]],
-                [[[21.0, 22.0], [25.0, 26.0]]],
-                [[[5.0, 6.0], [9.0, 10.0]]],
+                [[[5, 6], [9, 10]]],
+                [[[21, 22], [25, 26]]],
+                [[[21, 22], [25, 26]]],
+                [[[5, 6], [9, 10]]],
             ],
-            dtype=np.float32,
+            dtype=dt,
         ),
     )



+1 -4  imperative/python/test/unit/optimizer/test_clip_grad.py

@@ -5,11 +5,8 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import platform
-import weakref
 
 import numpy as np
-import pytest
 
 import megengine as mge
 import megengine.autodiff as ad
@@ -65,7 +62,7 @@ def test_clip_grad_value():
     gm = ad.GradManager().attach(net.parameters())
     opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
     with gm:
-        y = net(x)
+        y = net(mge.tensor(x))
         y = y.mean()
         gm.backward(y)
     save_grad_value(net)

