From 0d4568d6f87e940f0578784b244deb2847b62b42 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 9 Oct 2020 23:30:59 +0800
Subject: [PATCH] feat(mge): rename broadcast -> broadcast_to

GitOrigin-RevId: 82f46ad2c22d3b0e3e33783107825669487fc3a1
---
 .../python/megengine/autodiff/grad_manager.py      |  4 +++-
 .../megengine/core/autodiff/builtin_op_utils.py    |  4 ++--
 .../python/megengine/core/tensor/indexing.py       |  2 +-
 .../python/megengine/core/tensor/tensor_wrapper.py |  3 ++-
 imperative/python/megengine/functional/nn.py       | 27 ++++++++++++++--------
 imperative/python/megengine/functional/tensor.py   | 15 ++++++------
 imperative/python/megengine/functional/utils.py    |  6 ++---
 imperative/python/test/unit/core/test_autodiff.py  |  2 +-
 .../python/test/unit/functional/test_tensor.py     |  8 +++----
 imperative/python/test/unit/test_tracing.py        |  2 +-
 10 files changed, 42 insertions(+), 31 deletions(-)

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 6f68a2a2..001c9f9d 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -100,6 +100,8 @@ class GradManager:
         :param ys: outputs of forward operators, e.g., the loss tensor
         :param dys: derivatives of ys
         """
+        from ..functional import ones_like
+
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -113,7 +115,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
+            dys = [ones_like(y) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
index 51f54194..6ed12afb 100644
--- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py
+++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py
index cbbc61a1..e4bd8377 100644
--- a/imperative/python/megengine/core/tensor/indexing.py
+++ b/imperative/python/megengine/core/tensor/indexing.py
@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
                     value.shape, tmp_result.shape
                 )
             )
-        value = value.broadcast(tmp_result.shape)
+        value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:
diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py
index 800146e0..47c54441 100644
--- a/imperative/python/megengine/core/tensor/tensor_wrapper.py
+++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py
@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         return _reshape(self, _expand_args(args))
 
-    def broadcast(self, *args):
+    # FIXME: remove this method
+    def _broadcast(self, *args):
         return _broadcast(self, _expand_args(args))
 
     def transpose(self, *args):
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 6d15ff3f..1e62901d 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
+from .tensor import (
+    broadcast_to,
+    concat,
+    expand_dims,
+    full,
+    ones,
+    reshape,
+    squeeze,
+    zeros,
+)
 from .types import _pair, _pair_nonzero
 
 __all__ = [
@@ -635,7 +644,7 @@ def batch_norm2d(
     def full_value(value):
         C = inp.shape[1]
         (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast(x, [1, C, 1, 1])
+        return broadcast_to(x, [1, C, 1, 1])
 
     def expand_or_full(x, value):
         if x is None:
@@ -754,7 +763,7 @@ def sync_batch_norm(
 
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
         stat = concat(
             [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
         )
@@ -968,10 +977,10 @@ def matmul(
     if dim1 != dim2:
         if dim1 < dim2:
             shape1 = shape2[: dim2 - dim1] + shape1
-            inp1 = inp1.broadcast(*shape1)
+            inp1 = broadcast_to(inp1, shape1)
         else:
             shape2 = shape1[: dim1 - dim2] + shape2
-            inp2 = inp2.broadcast(*shape2)
+            inp2 = broadcast_to(inp2, shape2)
         reshaped_batch_size = 1
         for i in shape1[:-2]:
             reshaped_batch_size *= i
@@ -986,9 +995,9 @@ def matmul(
         shp = shape1[:-1] + shape2[-1:]
     elif dim1 == 3 or dim2 == 3:
         if dim2 < 3:
-            inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
+            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
         elif dim1 < 3:
-            inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
+            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1205,7 +1214,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     else:
         hscale = 1.0 * ih / oh
         wscale = 1.0 * iw / ow
@@ -1221,7 +1230,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
 
     weight = weight.astype("float32")
     ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 6c4cc743..b2c8d170 100644
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
-from ..core.tensor.tensor_wrapper import _remove_axis
+from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -33,7 +33,7 @@ from .elemwise import ceil
 
 __all__ = [
     "arange",
-    "broadcast",
+    "broadcast_to",
     "concat",
     "cond_take",
     "expand_dims",
@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
-    return broadcast(x, shape)
+    return broadcast_to(x, shape)
 
 
 def ones(shape, dtype="float32", device=None):
@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
     return output
 
 
-def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
+def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     """
     Broadcasts a tensor to given shape.
 
@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
         import megengine.functional as F
 
         data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-        out = F.broadcast(data, (4, 2, 3))
+        out = F.broadcast_to(data, (4, 2, 3))
         print(out.numpy())
 
     Outputs:
@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
           [3. 4. 5.]]]
 
     """
-    return inp.broadcast(shape)
+    return _broadcast(inp, shape)
 
 
 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
@@ -395,8 +395,7 @@ def _get_idx(index, axis):
                 0, index.shape[i] - 1, index.shape[i], device=index.device,
             )
             arange = (
-                arange.reshape(*shape)
-                .broadcast(index.shape)
+                broadcast_to(arange.reshape(*shape), index.shape)
                 .reshape(-1)
                 .astype(np.int32)
             )
diff --git a/imperative/python/megengine/functional/utils.py b/imperative/python/megengine/functional/utils.py
index 8b7359d5..8446f964 100644
--- a/imperative/python/megengine/functional/utils.py
+++ b/imperative/python/megengine/functional/utils.py
@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
 from ..core.tensor import Tensor
 from ..core.tensor.core import apply
 from .math import topk as _topk
-from .tensor import transpose as _transpose
+from .tensor import broadcast_to, transpose
 
 
 def accuracy(
@@ -54,8 +54,8 @@ def accuracy(
     _, pred = _topk(logits, k=max(topk), descending=True)
     accs = []
     for k in topk:
-        correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
-            target.shape[0], k
+        correct = pred[:, :k].detach() == broadcast_to(
+            transpose(target, (0, "x")), (target.shape[0], k)
         )
         accs.append(correct.astype(np.float32).sum() / target.shape[0])
     if len(topk) == 1:  # type: ignore[arg-type]
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 7e017e78..a910da26 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -319,7 +319,7 @@ def test_Broadcast():
     x = TensorWrapper(x_np)
 
     grad = Grad().wrt(x, callback=save_to(x))
-    y = F.broadcast(x, (3, 3, 10))
+    y = F.broadcast_to(x, (3, 3, 10))
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index b2ebafce..0b02b78f 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -251,17 +251,17 @@ def test_broadcast():
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
     ]
-    opr_test(cases, F.broadcast, compare_fn=compare_fn)
+    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
 
     x = F.ones((2, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (2, 3, 4))
+        F.broadcast_to(x, (2, 3, 4))
 
     with pytest.raises(ValueError):
-        F.broadcast(x, (4, 1, 3))
+        F.broadcast_to(x, (4, 1, 3))
 
     with pytest.raises(ValueError):
-        F.broadcast(x, (1, 3))
+        F.broadcast_to(x, (1, 3))
 
 
 def test_utils_astensor1d():
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index c9d7abd2..d54206d9 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -351,7 +351,7 @@ def test_trace_broadcast():
 
     @trace(symbolic=symbolic, capture_as_const=True)
     def f(x):
-        y = x.broadcast((3, 4, 5))
+        y = F.broadcast_to(x, (3, 4, 5))
         return y
 
     f(x1)
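Note for downstream users (not part of the patch): a minimal migration sketch, assuming the public MegEngine Python API as exposed by this change. The input values and the target shape are illustrative only and mirror the broadcast_to docstring example above.

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
    # Before this patch: F.broadcast(x, (4, 2, 3)) or x.broadcast((4, 2, 3)).
    # After it, the tensor method becomes private (_broadcast) and the public
    # entry point is the renamed functional form:
    y = F.broadcast_to(x, (4, 2, 3))
    print(y.numpy().shape)  # (4, 2, 3)

The broadcasting rules themselves are unchanged; only the name moves, so replacing F.broadcast(...) and Tensor.broadcast(...) call sites with F.broadcast_to(...) is normally sufficient.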