@@ -100,6 +100,8 @@ class GradManager:
         :param ys: outputs of forward operators, e.g., the loss tensor
         :param dys: derivatives of ys
         """
+        from ..functional import ones_like
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -113,7 +115,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
+            dys = [ones_like(y) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
                         value.shape, tmp_result.shape
                     )
                 )
-        value = value.broadcast(tmp_result.shape)
+        value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:
@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         return _reshape(self, _expand_args(args))
-    def broadcast(self, *args):
+    # FIXME: remove this method
+    def _broadcast(self, *args):
         return _broadcast(self, _expand_args(args))
     def transpose(self, *args):
@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
+from .tensor import (
+    broadcast_to,
+    concat,
+    expand_dims,
+    full,
+    ones,
+    reshape,
+    squeeze,
+    zeros,
+)
 from .types import _pair, _pair_nonzero
 __all__ = [
@@ -635,7 +644,7 @@ def batch_norm2d(
     def full_value(value):
         C = inp.shape[1]
         (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast(x, [1, C, 1, 1])
+        return broadcast_to(x, [1, C, 1, 1])
     def expand_or_full(x, value):
         if x is None:
@@ -754,7 +763,7 @@ def sync_batch_norm(
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
         stat = concat(
             [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
         )
@@ -968,10 +977,10 @@ def matmul(
     if dim1 != dim2:
         if dim1 < dim2:
             shape1 = shape2[: dim2 - dim1] + shape1
-            inp1 = inp1.broadcast(*shape1)
+            inp1 = broadcast_to(inp1, shape1)
         else:
             shape2 = shape1[: dim1 - dim2] + shape2
-            inp2 = inp2.broadcast(*shape2)
+            inp2 = broadcast_to(inp2, shape2)
     reshaped_batch_size = 1
     for i in shape1[:-2]:
         reshaped_batch_size *= i
@@ -986,9 +995,9 @@ def matmul(
         shp = shape1[:-1] + shape2[-1:]
     elif dim1 == 3 or dim2 == 3:
         if dim2 < 3:
-            inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
+            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
         elif dim1 < 3:
-            inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
+            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1205,7 +1214,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     else:
         hscale = 1.0 * ih / oh
         wscale = 1.0 * iw / ow
@@ -1221,7 +1230,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     weight = weight.astype("float32")
     ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
-from ..core.tensor.tensor_wrapper import _remove_axis
+from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -33,7 +33,7 @@ from .elemwise import ceil
 __all__ = [
     "arange",
-    "broadcast",
+    "broadcast_to",
     "concat",
     "cond_take",
     "expand_dims",
@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
-    return broadcast(x, shape)
+    return broadcast_to(x, shape)
 def ones(shape, dtype="float32", device=None):
@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
     return output
-def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
+def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     """
     Broadcasts a tensor to given shape.
@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
         import megengine.functional as F
         data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-        out = F.broadcast(data, (4, 2, 3))
+        out = F.broadcast_to(data, (4, 2, 3))
         print(out.numpy())
     Outputs:
@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
          [3. 4. 5.]]]
     """
-    return inp.broadcast(shape)
+    return _broadcast(inp, shape)
 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
@@ -395,8 +395,7 @@ def _get_idx(index, axis):
                 0, index.shape[i] - 1, index.shape[i], device=index.device,
             )
             arange = (
-                arange.reshape(*shape)
-                .broadcast(index.shape)
+                broadcast_to(arange.reshape(*shape), index.shape)
                 .reshape(-1)
                 .astype(np.int32)
             )
@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
 from ..core.tensor import Tensor
 from ..core.tensor.core import apply
 from .math import topk as _topk
-from .tensor import transpose as _transpose
+from .tensor import broadcast_to, transpose
 def accuracy(
@@ -54,8 +54,8 @@ def accuracy(
     _, pred = _topk(logits, k=max(topk), descending=True)
     accs = []
     for k in topk:
-        correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
-            target.shape[0], k
+        correct = pred[:, :k].detach() == broadcast_to(
+            transpose(target, (0, "x")), (target.shape[0], k)
         )
         accs.append(correct.astype(np.float32).sum() / target.shape[0])
     if len(topk) == 1:  # type: ignore[arg-type]
@@ -319,7 +319,7 @@ def test_Broadcast():
     x = TensorWrapper(x_np)
     grad = Grad().wrt(x, callback=save_to(x))
-    y = F.broadcast(x, (3, 3, 10))
+    y = F.broadcast_to(x, (3, 3, 10))
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
@@ -251,17 +251,17 @@ def test_broadcast():
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
     ]
-    opr_test(cases, F.broadcast, compare_fn=compare_fn)
+    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
     x = F.ones((2, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (2, 3, 4))
+        F.broadcast_to(x, (2, 3, 4))
     with pytest.raises(ValueError):
-        F.broadcast(x, (4, 1, 3))
+        F.broadcast_to(x, (4, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (1, 3))
+        F.broadcast_to(x, (1, 3))
 def test_utils_astensor1d():
@@ -351,7 +351,7 @@ def test_trace_broadcast():
         @trace(symbolic=symbolic, capture_as_const=True)
         def f(x):
-            y = x.broadcast((3, 4, 5))
+            y = F.broadcast_to(x, (3, 4, 5))
             return y
         f(x1)
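
For reference, a minimal usage sketch of the renamed API, adapted from the broadcast_to docstring in the hunk above (not part of the patch; the printed shape comment is illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
    # old method-style call removed by this patch: data.broadcast((4, 2, 3))
    out = F.broadcast_to(data, (4, 2, 3))
    print(out.numpy().shape)  # (4, 2, 3)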