
feat(mge): rename broadcast -> broadcast_to

GitOrigin-RevId: 82f46ad2c2
Branch: release-1.1
Megvii Engine Team, 4 years ago
parent commit 0d4568d6f8
10 changed files with 42 additions and 31 deletions
 1. +3  -1  imperative/python/megengine/autodiff/grad_manager.py
 2. +2  -2  imperative/python/megengine/core/autodiff/builtin_op_utils.py
 3. +1  -1  imperative/python/megengine/core/tensor/indexing.py
 4. +2  -1  imperative/python/megengine/core/tensor/tensor_wrapper.py
 5. +18 -9  imperative/python/megengine/functional/nn.py
 6. +7  -8  imperative/python/megengine/functional/tensor.py
 7. +3  -3  imperative/python/megengine/functional/utils.py
 8. +1  -1  imperative/python/test/unit/core/test_autodiff.py
 9. +4  -4  imperative/python/test/unit/functional/test_tensor.py
10. +1  -1  imperative/python/test/unit/test_tracing.py

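For code migrating across this commit, a minimal before/after sketch of the rename (illustrative only; the shapes are taken from the `broadcast_to` docstring changed below):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))

# Old spelling, removed by this commit:
#   out = F.broadcast(data, (4, 2, 3))
# New spelling:
out = F.broadcast_to(data, (4, 2, 3))
print(out.numpy().shape)  # (4, 2, 3)
```
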
+3 -1  imperative/python/megengine/autodiff/grad_manager.py

@@ -100,6 +100,8 @@ class GradManager:
         :param ys: outputs of forward operators, e.g., the loss tensor
         :param dys: derivatives of ys
         """
+        from ..functional import ones_like
+
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -113,7 +115,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
+            dys = [ones_like(y) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:

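A hedged sketch of why the new default seed is equivalent: for each output y, `ones_like(y)` produces the same all-ones tensor that `tensor(1.0).broadcast(y.shape)` used to, without relying on the removed Tensor method (example values are made up):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

y = tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
dy = F.ones_like(y)  # new default gradient seed for backward()
np.testing.assert_equal(dy.numpy(), np.ones((2, 3), dtype=np.float32))
```
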

+2 -2  imperative/python/megengine/core/autodiff/builtin_op_utils.py

@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)


+1 -1  imperative/python/megengine/core/tensor/indexing.py

@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
                 value.shape, tmp_result.shape
             )
         )
-    value = value.broadcast(tmp_result.shape)
+    value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:


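The renamed call keeps setitem's behavior: the assigned value is broadcast to the shape of the indexed region before the SetSubtensor op is applied. A small usage sketch (illustrative values, not from the diff):

```python
import numpy as np
from megengine import tensor

a = tensor(np.zeros((3, 4), dtype=np.float32))
a[0:2] = tensor(1.0)  # scalar broadcast to the (2, 4) indexed region
print(a.numpy()[:2].sum())  # 8.0
```
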
+2 -1  imperative/python/megengine/core/tensor/tensor_wrapper.py

@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         return _reshape(self, _expand_args(args))

-    def broadcast(self, *args):
+    # FIXME: remove this method
+    def _broadcast(self, *args):
         return _broadcast(self, _expand_args(args))

     def transpose(self, *args):


+18 -9  imperative/python/megengine/functional/nn.py

@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
+from .tensor import (
+    broadcast_to,
+    concat,
+    expand_dims,
+    full,
+    ones,
+    reshape,
+    squeeze,
+    zeros,
+)
 from .types import _pair, _pair_nonzero

 __all__ = [
@@ -635,7 +644,7 @@ def batch_norm2d(
     def full_value(value):
         C = inp.shape[1]
         (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast(x, [1, C, 1, 1])
+        return broadcast_to(x, [1, C, 1, 1])

     def expand_or_full(x, value):
         if x is None:
@@ -754,7 +763,7 @@ def sync_batch_norm(

     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
         stat = concat(
             [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
         )
@@ -968,10 +977,10 @@ def matmul(
     if dim1 != dim2:
         if dim1 < dim2:
             shape1 = shape2[: dim2 - dim1] + shape1
-            inp1 = inp1.broadcast(*shape1)
+            inp1 = broadcast_to(inp1, shape1)
         else:
             shape2 = shape1[: dim1 - dim2] + shape2
-            inp2 = inp2.broadcast(*shape2)
+            inp2 = broadcast_to(inp2, shape2)
     reshaped_batch_size = 1
     for i in shape1[:-2]:
         reshaped_batch_size *= i
@@ -986,9 +995,9 @@ def matmul(
         shp = shape1[:-1] + shape2[-1:]
     elif dim1 == 3 or dim2 == 3:
         if dim2 < 3:
-            inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
+            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
         elif dim1 < 3:
-            inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
+            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
     op = builtin.BatchedMatrixMul(
         transposeA=transpose_a,
         transposeB=transpose_b,
@@ -1205,7 +1214,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     else:
         hscale = 1.0 * ih / oh
         wscale = 1.0 * iw / ow
@@ -1221,7 +1230,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))

     weight = weight.astype("float32")
     ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")

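The matmul hunks above only change how batch dimensions are expanded; the user-visible behavior stays NumPy-style. A hedged sketch of the dim1 == 3, dim2 == 2 path (shapes are made up):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

a = tensor(np.ones((4, 2, 3), dtype=np.float32))  # batched, dim1 == 3
b = tensor(np.ones((3, 5), dtype=np.float32))     # plain matrix, dim2 == 2
# Internally b is expanded via broadcast_to(b, a.shape[:1] + b.shape),
# i.e. to (4, 3, 5), before BatchedMatrixMul runs.
out = F.matmul(a, b)
print(out.numpy().shape)  # (4, 2, 5)
```
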

+7 -8  imperative/python/megengine/functional/tensor.py

@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
-from ..core.tensor.tensor_wrapper import _remove_axis
+from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -33,7 +33,7 @@ from .elemwise import ceil

 __all__ = [
     "arange",
-    "broadcast",
+    "broadcast_to",
     "concat",
     "cond_take",
     "expand_dims",
@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
-    return broadcast(x, shape)
+    return broadcast_to(x, shape)


 def ones(shape, dtype="float32", device=None):
@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
     return output


-def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
+def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     """
     Broadcasts a tensor to given shape.

@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
        import megengine.functional as F

        data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-       out = F.broadcast(data, (4, 2, 3))
+       out = F.broadcast_to(data, (4, 2, 3))
        print(out.numpy())

    Outputs:
@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
        [3. 4. 5.]]]

    """
-    return inp.broadcast(shape)
+    return _broadcast(inp, shape)


 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
@@ -395,8 +395,7 @@ def _get_idx(index, axis):
                 0, index.shape[i] - 1, index.shape[i], device=index.device,
             )
             arange = (
-                arange.reshape(*shape)
-                .broadcast(index.shape)
+                broadcast_to(arange.reshape(*shape), index.shape)
                 .reshape(-1)
                 .astype(np.int32)
             )


+3 -3  imperative/python/megengine/functional/utils.py

@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
 from ..core.tensor import Tensor
 from ..core.tensor.core import apply
 from .math import topk as _topk
-from .tensor import transpose as _transpose
+from .tensor import broadcast_to, transpose


 def accuracy(
@@ -54,8 +54,8 @@ def accuracy(
     _, pred = _topk(logits, k=max(topk), descending=True)
     accs = []
     for k in topk:
-        correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
-            target.shape[0], k
+        correct = pred[:, :k].detach() == broadcast_to(
+            transpose(target, (0, "x")), (target.shape[0], k)
         )
         accs.append(correct.astype(np.float32).sum() / target.shape[0])
     if len(topk) == 1:  # type: ignore[arg-type]


+1 -1  imperative/python/test/unit/core/test_autodiff.py

@@ -319,7 +319,7 @@ def test_Broadcast():
     x = TensorWrapper(x_np)

     grad = Grad().wrt(x, callback=save_to(x))
-    y = F.broadcast(x, (3, 3, 10))
+    y = F.broadcast_to(x, (3, 3, 10))

     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())


+4 -4  imperative/python/test/unit/functional/test_tensor.py

@@ -251,17 +251,17 @@ def test_broadcast():
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
     ]
-    opr_test(cases, F.broadcast, compare_fn=compare_fn)
+    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

     x = F.ones((2, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (2, 3, 4))
+        F.broadcast_to(x, (2, 3, 4))

     with pytest.raises(ValueError):
-        F.broadcast(x, (4, 1, 3))
+        F.broadcast_to(x, (4, 1, 3))

     with pytest.raises(ValueError):
-        F.broadcast(x, (1, 3))
+        F.broadcast_to(x, (1, 3))


 def test_utils_astensor1d():

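The ValueError cases in the test above follow the usual broadcasting rule: an axis can only be expanded if its size is 1, and the target needs at least as many dimensions as the source, with new axes prepended. A hedged sketch of both the allowed and rejected shapes:

```python
import megengine.functional as F

x = F.ones((2, 1, 3))
print(F.broadcast_to(x, (4, 2, 1, 3)).numpy().shape)  # new leading axis: OK
print(F.broadcast_to(x, (2, 5, 3)).numpy().shape)     # expand size-1 axis: OK
# Each of the following raises ValueError, as the tests assert:
#   F.broadcast_to(x, (2, 3, 4))  # trailing 3 vs 4 mismatch
#   F.broadcast_to(x, (4, 1, 3))  # axis of size 2 cannot become 4
#   F.broadcast_to(x, (1, 3))     # target has fewer dims than x
```
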

+1 -1  imperative/python/test/unit/test_tracing.py

@@ -351,7 +351,7 @@ def test_trace_broadcast():

     @trace(symbolic=symbolic, capture_as_const=True)
     def f(x):
-        y = x.broadcast((3, 4, 5))
+        y = F.broadcast_to(x, (3, 4, 5))
         return y

     f(x1)

