@@ -100,6 +100,8 @@ class GradManager:
         :param ys: outputs of forward operators, e.g., the loss tensor
         :param dys: derivatives of ys
         """
+        from ..functional import ones_like
+
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -113,7 +115,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
+            dys = [ones_like(y) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
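
For context, a minimal usage sketch (not part of the patch) of what the new default means for callers; the tensor names are placeholders and it assumes the public GradManager attach/record/backward API:

    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    w = mge.tensor([1.0, 2.0, 3.0])
    gm = GradManager()
    gm.attach([w])
    with gm:
        loss = (w * w).sum()
        # with this patch, gm.backward(loss) defaults dys to ones_like(loss),
        # i.e. it is equivalent to passing the all-ones gradient explicitly:
        gm.backward(loss, F.ones_like(loss))
    print(w.grad)  # expected: [2. 4. 6.]
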
@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
                 value.shape, tmp_result.shape
             )
         )
-    value = value.broadcast(tmp_result.shape)
+    value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:
@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         return _reshape(self, _expand_args(args))
 
-    def broadcast(self, *args):
+    # FIXME: remove this method
+    def _broadcast(self, *args):
         return _broadcast(self, _expand_args(args))
 
     def transpose(self, *args):
@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
+from .tensor import (
+    broadcast_to,
+    concat,
+    expand_dims,
+    full,
+    ones,
+    reshape,
+    squeeze,
+    zeros,
+)
 from .types import _pair, _pair_nonzero
 
 __all__ = [
@@ -635,7 +644,7 @@ def batch_norm2d(
     def full_value(value):
         C = inp.shape[1]
         (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast(x, [1, C, 1, 1])
+        return broadcast_to(x, [1, C, 1, 1])
 
     def expand_or_full(x, value):
         if x is None:
@@ -754,7 +763,7 @@ def sync_batch_norm(
 
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
         stat = concat(
             [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
         )
@@ -968,10 +977,10 @@ def matmul(
         if dim1 != dim2:
             if dim1 < dim2:
                 shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = inp1.broadcast(*shape1)
+                inp1 = broadcast_to(inp1, shape1)
             else:
                 shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = inp2.broadcast(*shape2)
+                inp2 = broadcast_to(inp2, shape2)
         reshaped_batch_size = 1
         for i in shape1[:-2]:
             reshaped_batch_size *= i
@@ -986,9 +995,9 @@ def matmul(
         shp = shape1[:-1] + shape2[-1:]
     elif dim1 == 3 or dim2 == 3:
         if dim2 < 3:
-            inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
+            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
         elif dim1 < 3:
-            inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
+            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1205,7 +1214,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     else:
         hscale = 1.0 * ih / oh
         wscale = 1.0 * iw / ow
@@ -1221,7 +1230,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
 
     weight = weight.astype("float32")
     ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
-from ..core.tensor.tensor_wrapper import _remove_axis
+from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -33,7 +33,7 @@ from .elemwise import ceil
 
 __all__ = [
     "arange",
-    "broadcast",
+    "broadcast_to",
     "concat",
     "cond_take",
     "expand_dims",
@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
-    return broadcast(x, shape)
+    return broadcast_to(x, shape)
 
 
 def ones(shape, dtype="float32", device=None):
@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
     return output
 
 
-def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
+def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     """
     Broadcasts a tensor to given shape.
 
@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
         import megengine.functional as F
 
         data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-        out = F.broadcast(data, (4, 2, 3))
+        out = F.broadcast_to(data, (4, 2, 3))
         print(out.numpy())
 
     Outputs:
@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
         [3. 4. 5.]]]
 
     """
-    return inp.broadcast(shape)
+    return _broadcast(inp, shape)
 
 
 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
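
A hedged migration sketch for call sites, assuming the post-patch API; the old spelling survives only as the private ArrayMethodMixin._broadcast helper and is shown here commented out:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(6, dtype=np.float32).reshape(2, 3))

    # old public spelling (tensor method), removed by this patch:
    # y = x.broadcast((4, 2, 3))

    # new public spelling (functional API):
    y = F.broadcast_to(x, (4, 2, 3))
    print(y.shape)  # (4, 2, 3)
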
@@ -395,8 +395,7 @@ def _get_idx(index, axis):
                 0, index.shape[i] - 1, index.shape[i], device=index.device,
             )
             arange = (
-                arange.reshape(*shape)
-                .broadcast(index.shape)
+                broadcast_to(arange.reshape(*shape), index.shape)
                 .reshape(-1)
                 .astype(np.int32)
             )
@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
 from ..core.tensor import Tensor
 from ..core.tensor.core import apply
 from .math import topk as _topk
-from .tensor import transpose as _transpose
+from .tensor import broadcast_to, transpose
 
 
 def accuracy(
@@ -54,8 +54,8 @@ def accuracy(
     _, pred = _topk(logits, k=max(topk), descending=True)
     accs = []
     for k in topk:
-        correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
-            target.shape[0], k
+        correct = pred[:, :k].detach() == broadcast_to(
+            transpose(target, (0, "x")), (target.shape[0], k)
         )
         accs.append(correct.astype(np.float32).sum() / target.shape[0])
     if len(topk) == 1:  # type: ignore[arg-type]
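
For readers skimming the hunk, a small numpy-only sketch (not MegEngine code, data made up for illustration) of the broadcast-and-compare step: pred[:, :k] has shape (N, k), the target column (N, 1) is broadcast to (N, k), and the matches are counted.

    import numpy as np

    N, k = 4, 3
    pred_topk = np.array([[2, 0, 1], [1, 2, 0], [0, 1, 2], [2, 1, 0]])  # (N, k) top-k labels
    target = np.array([2, 3, 1, 0])                                     # (N,) ground truth
    correct = pred_topk == np.broadcast_to(target[:, None], (N, k))     # (N, k) bool
    print(correct.astype(np.float32).sum() / N)  # 0.75
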
@@ -319,7 +319,7 @@ def test_Broadcast():
     x = TensorWrapper(x_np)
     grad = Grad().wrt(x, callback=save_to(x))
 
-    y = F.broadcast(x, (3, 3, 10))
+    y = F.broadcast_to(x, (3, 3, 10))
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
@@ -251,17 +251,17 @@ def test_broadcast():
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
     ]
-    opr_test(cases, F.broadcast, compare_fn=compare_fn)
+    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
 
     x = F.ones((2, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (2, 3, 4))
+        F.broadcast_to(x, (2, 3, 4))
 
     with pytest.raises(ValueError):
-        F.broadcast(x, (4, 1, 3))
+        F.broadcast_to(x, (4, 1, 3))
 
     with pytest.raises(ValueError):
-        F.broadcast(x, (1, 3))
+        F.broadcast_to(x, (1, 3))
 
 
 def test_utils_astensor1d():
@@ -351,7 +351,7 @@ def test_trace_broadcast():
 
         @trace(symbolic=symbolic, capture_as_const=True)
         def f(x):
-            y = x.broadcast((3, 4, 5))
+            y = F.broadcast_to(x, (3, 4, 5))
             return y
 
         f(x1)
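
A short recap of the broadcasting rules these tests pin down, as a hedged sketch assuming numpy-style semantics; the shapes are chosen purely for illustration and the rejected calls are left commented out:

    import megengine.functional as F

    x = F.ones((2, 1, 3))
    print(F.broadcast_to(x, (4, 2, 1, 3)).shape)  # ok: new leading axes may be added
    print(F.broadcast_to(x, (2, 4, 3)).shape)     # ok: size-1 axes may be expanded

    # the cases rejected in test_broadcast above:
    # F.broadcast_to(x, (2, 3, 4))  -> ValueError, trailing axis of size 3 cannot become 4
    # F.broadcast_to(x, (4, 1, 3))  -> ValueError, existing axis of size 2 cannot become 4
    # F.broadcast_to(x, (1, 3))     -> ValueError, target rank is lower than the input's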