
fix(imperative): make functional ops support negative axis

GitOrigin-RevId: f61e01270b
release-1.10
Megvii Engine Team 3 years ago
commit 24c5c19bf0
12 changed files with 469 additions and 96 deletions
  1. dnn/include/megdnn/oprs/general.h (+1, -1)
  2. imperative/python/megengine/functional/nn.py (+2, -2)
  3. imperative/python/megengine/functional/tensor.py (+5, -24)
  4. imperative/python/src/grad_override.cpp (+82, -0)
  5. imperative/python/test/unit/core/test_autodiff.py (+47, -0)
  6. imperative/python/test/unit/core/test_function.py (+1, -3)
  7. imperative/python/test/unit/functional/test_loss.py (+19, -14)
  8. imperative/python/test/unit/functional/test_math.py (+62, -9)
  9. imperative/python/test/unit/functional/test_tensor.py (+120, -15)
  10. imperative/src/impl/ops/indexing.cpp (+120, -11)
  11. imperative/src/impl/ops/specializations.cpp (+0, -15)
  12. src/core/include/megbrain/ir/ops.td (+10, -2)
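
In short, after this change the functional reductions, scans, and indexing ops covered here accept a negative axis that counts from the last dimension. A condensed sketch of the new behaviour, adapted from the unit tests added below (it assumes a MegEngine build that already contains this fix and is not part of the diff itself):

import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
# Reductions and scans resolve axis=-1 to the last dimension.
np.testing.assert_equal(F.sum(x, axis=-1).numpy(), np.array([6, 15], np.int32))
np.testing.assert_equal(F.cumsum(x, -1).numpy(), np.array([[1, 3, 6], [4, 9, 15]], np.int32))

# indexing_one_hot likewise accepts a negative axis.
src = Tensor([[1.0, 2.0]])
index = Tensor([0])
np.testing.assert_equal(F.indexing_one_hot(src, index, -1).numpy(), np.array([1.0], np.float32))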

dnn/include/megdnn/oprs/general.h (+1, -1)

@@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase {
DEF_OPR_PARAM(Axis);

protected:
void deduce_layout_fwd(
MGE_WIN_DECLSPEC_FUC void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& index, TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& index,


imperative/python/megengine/functional/nn.py (+2, -2)

@@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
)
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)

op = builtin.IndexingSetOneHot(axis=inp.ndim)
op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor)
return result

@@ -1609,7 +1609,7 @@ def indexing_one_hot(
array([1.], dtype=float32)
"""
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim)
index = convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:


imperative/python/megengine/functional/tensor.py (+5, -24)

@@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0):
def _get_idx(index, axis):
index_dims = len(index.shape)
idx = []
if axis < 0:
axis += index_dims
for i in range(index_dims):
if i != axis:
shape = [1] * index_dims
@@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
"But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
)

if axis < 0 or axis >= input_dims:
raise ValueError(
"Index axis {} is output of bounds, should in range [0 {})".format(
axis, input_dims
)
)

for i in range(input_dims):
if i != axis and input_shape[i] != index_shape[i]:
raise ValueError(
"The input {} and index {} must have the same size apart from axis {}".format(
input_shape, index_shape, axis
)
)

idx = _get_idx(index, axis)
return inp[idx].reshape(index.shape) # pylint: disable=no-member

@@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
>>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32))
>>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
>>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]])
>>> oup = F.scatter(inp, 0, index,source)
>>> oup = F.scatter(inp, 0, index, source)
>>> oup.numpy()
array([[0.9935, 0.0718, 0.2256, 0. , 0. ],
[0. , 0. , 0.5939, 0.357 , 0.4396],
@@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
if input_dims != index_dims or input_dims != source_dims:
raise ValueError("The input, source and index tensor must have same dimensions")

if axis < 0 or axis >= input_dims:
raise ValueError(
"Index axis {} is output of bounds, should in range [0 {})".format(
axis, input_dims
)
)

for i in range(source_dims):
if source_shape[i] > input_shape[i]:
raise ValueError(
@@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
>>> out.numpy().shape
(2, 2, 9)
"""
if start_axis < 0:
start_axis += len(inp.shape)
target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
if end_axis != -1:
target_shape += (*inp.shape[end_axis + 1 :],)
@@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int):
[ 4 9 15]], dtype=int32, device=xpux:0)
"""
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
return apply(op, inp)[0]
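
For context, F.gather picks elements of inp along axis at the positions given by index, roughly the semantics of numpy.take_along_axis; with this change a negative axis is first folded into the non-negative range inside _get_idx. A reference sketch in NumPy, under that assumption and not the MegEngine implementation:

import numpy as np

def gather_ref(inp, axis, index):
    # Negative axis counts from the last dimension, as in NumPy.
    if axis < 0:
        axis += inp.ndim
    return np.take_along_axis(inp, index, axis=axis)

x = np.array([[1, 2], [3, 4], [5, 6]])
idx = np.array([[0, 1], [1, 0], [1, 1]])
print(gather_ref(x, -1, idx))  # [[1 2] [4 3] [6 6]] -- matches the new test_gather (which uses axis=1 on this 2-D input)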

imperative/python/src/grad_override.cpp (+82, -0)

@@ -490,6 +490,84 @@ std::optional<ValueRefList> pixelShuffle_grad_rule(
return imperative::apply(op, inputs);
}

std::optional<ValueRefList> indexing_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexing = op.cast_final_safe<IndexingOneHot>();
mgb_assert(inputs.size() == 2);
bool flag = inputs_require_grad[0];
auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim);
SmallVector<ValueRef> inputs2;
if (flag) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(1);
if (grad && inputs[0]) {
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = inputs[1];
args_[2] = grads[0];
ret[0] = imperative::apply(*grad_op_, args_)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}

std::optional<ValueRefList> indexing_set_one_hot_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingSetOneHot = op.cast_final_safe<IndexingSetOneHot>();
mgb_assert(inputs.size() == 3);
SmallVector<ValueRef> inputs2;
inputs2.push_back(get_shape(inputs[0]));
inputs2.push_back(inputs[1]);
inputs2.push_back(get_shape(inputs[2]));
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
&indexingSetOneHot](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(3);
if (!grad) {
return ret;
}
if (inputs[0]) {
auto&& grad_op = IndexingSetOneHot::make(
indexingSetOneHot.axis, indexingSetOneHot.ndim);
ValueRefList args_(inputs.size());
auto&& zeros = make_empty_tensor(grad.device(), inputs[2], grad.dtype());
args_[0] = grads[0];
args_[1] = inputs[1];
args_[2] = zeros;
ret[0] = imperative::apply(*grad_op, args_)[0];
}
if (inputs[2]) {
auto&& grad_op = IndexingOneHot::make(
indexingSetOneHot.axis, indexingSetOneHot.ndim);
ValueRefList args_(inputs.size() - 1);
args_[0] = grads[0];
args_[1] = inputs[1];
ret[2] = imperative::apply(*grad_op, args_)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}

std::optional<ValueRefList> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
@@ -522,6 +600,10 @@ struct Init {
CustomBackward::register_grad_rule(
RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
IndexingOneHot::typeinfo(), indexing_grad_rule);
CustomBackward::register_grad_rule(
IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule);
CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(
PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
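
Note on the two grad rules added in this file: the backward of IndexingOneHot with respect to src scatters the incoming output gradient into a zero-filled tensor at the indexed positions (which is what applying IndexingSetOneHot to (zeros, index, grad) computes), while the backward of IndexingSetOneHot zeroes those positions for the data input and reads the gradient at them back out with IndexingOneHot for the value input. A NumPy sketch of the first relationship for a 2-D src and axis=1, purely illustrative and not MegEngine code:

import numpy as np

def indexing_one_hot_ref(src, index):
    # forward: out[i, 0] = src[i, index[i]]
    return src[np.arange(src.shape[0]), index][:, None]

def indexing_one_hot_grad_wrt_src(src_shape, index, grad_out):
    # Scatter the output gradient into zeros at the indexed positions,
    # i.e. IndexingSetOneHot(zeros_like(src), index, grad_out).
    grad_src = np.zeros(src_shape, dtype=grad_out.dtype)
    grad_src[np.arange(src_shape[0]), index] = grad_out[:, 0]
    return grad_src

src = np.array([[1.0, 2.0]], dtype=np.float32)
index = np.array([0], dtype=np.int32)
out = indexing_one_hot_ref(src, index)                         # [[1.]]
grad_src = indexing_one_hot_grad_wrt_src(src.shape, index, np.ones_like(out))
print(grad_src)                                                # [[1. 0.]] -- cf. test_indexing below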


imperative/python/test/unit/core/test_autodiff.py (+47, -0)

@@ -8,11 +8,15 @@ import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.core import _imperative_rt
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise, Identity
from megengine.functional.distributed import remote_recv, remote_send
from megengine.functional.tensor import ones, zeros


def _elwise(mode):
@@ -553,3 +557,46 @@ def test_matmul():
if ydim == 1 and transposeB == True:
continue
test_one(xdim, ydim, transposeA, transposeB)


def test_indexing():
x = np.array([[1.0, 2.0]]).astype("float32")
x = mge.Tensor(x)
index = mge.Tensor([0])

with Grad() as grad:
grad.wrt(x, callback=save_to(x))

def f(x):
return F.indexing_one_hot(x, index, -1)

y = f(x)
grad(y, F.ones_like(y))

np.testing.assert_equal(np.array([[1, 0]], dtype=np.float32), x.grad.numpy())


def test_indexing_set_one_hot():
x = mge.tensor(np.arange(1, 4, dtype=np.int32))

with Grad() as grad:
zeros_tensor = zeros((3, 4), dtype=x.dtype, device=x.device)
ones_tensor = ones((3, 1), dtype=x.dtype, device=x.device)

grad.wrt(zeros_tensor, callback=save_to(zeros_tensor))
grad.wrt(ones_tensor, callback=save_to(ones_tensor))

def f(x):
op = builtin.IndexingSetOneHot(axis=x.ndim, ndim=x.ndim)
(result,) = apply(op, zeros_tensor, x, ones_tensor)
return result

y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]], dtype=np.int32),
zeros_tensor.grad.numpy(),
)
np.testing.assert_equal(
np.array([[1], [1], [1]], dtype=np.int32), ones_tensor.grad.numpy(),
)

imperative/python/test/unit/core/test_function.py (+1, -3)

@@ -6,9 +6,7 @@ import pytest
import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optimizer
from megengine import Parameter
from megengine import Tensor as tensor
from megengine import tensor
from megengine import Parameter, Tensor, tensor
from megengine.autodiff import Function
from megengine.module import Module



imperative/python/test/unit/functional/test_loss.py (+19, -14)

@@ -3,15 +3,15 @@ import numpy as np
import pytest

import megengine.functional as F
from megengine import tensor
from megengine import Tensor


def test_cross_entropy_with_logits():
data = tensor([[0, 50], [0, -150]]).astype(np.float32)
label = tensor([1, 0]).astype(np.int32)
data = Tensor([[0, 50], [0, -150]]).astype(np.float32)
label = Tensor([1, 0]).astype(np.int32)
loss = F.nn.cross_entropy(data, label)
np.testing.assert_allclose(loss.numpy(), 0.0)
label = tensor([0, 1]).astype(np.int32)
label = Tensor([0, 1]).astype(np.int32)
loss = F.nn.cross_entropy(data, label)
np.testing.assert_allclose(loss.numpy(), 100)

@@ -35,19 +35,24 @@ def test_cross_entropy():
x[i, y[i]] += np.random.rand() * 2
x = softmax(x)
l_ref = ref(x, y)
l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
l = F.nn.cross_entropy(Tensor(x, "float32"), Tensor(y, "int32"), with_logits=False)
np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)

l1 = F.nn.cross_entropy(
Tensor(x, "float32"), Tensor(y, "int32"), axis=-1, with_logits=False
)
np.testing.assert_allclose(l1.numpy(), l_ref, 1e-6, 1e-6)


def test_cross_entropy_reduction():
logits = np.random.randn(16, 10)
label = np.random.randint(10, size=[16])
logits = tensor(logits, dtype="float32")
label = tensor(label, dtype="int32")
logits = Tensor(logits, dtype="float32")
label = Tensor(label, dtype="int32")

perm = np.random.permutation(16)
logits_perm = tensor(logits[perm], dtype="float32")
label_perm = tensor(label[perm], dtype="int32")
logits_perm = Tensor(logits[perm], dtype="float32")
label_perm = Tensor(label[perm], dtype="int32")

loss = F.nn.cross_entropy(logits, label, reduction="none")
loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
@@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank):
def test_ctc_loss():
def test_func(T, C, N):
input = np.random.randn(T, N, C)
input = F.softmax(tensor(input), axis=-1).numpy()
input = F.softmax(Tensor(input), axis=-1).numpy()
input_lengths = np.ones(N, dtype=np.int32) * T
target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
target = np.random.randint(
low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
)

input_mge = tensor(input)
input_lengths_mge = tensor(input_lengths)
input_mge = Tensor(input)
input_lengths_mge = Tensor(input_lengths)

target_mge = tensor(target)
target_lengths_mge = tensor(target_lengths)
target_mge = Tensor(target)
target_lengths_mge = Tensor(target_lengths)

blank = np.random.randint(C)
for method in ["mean", "sum", "none"]:


imperative/python/test/unit/functional/test_math.py (+62, -9)

@@ -6,7 +6,7 @@ import pytest
from utils import opr_test

import megengine.functional as F
from megengine import jit, tensor
from megengine import Tensor, jit, tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin

@@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr):
def test_sum():
common_test_reduce(opr=F.sum, ref_opr=np.sum)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.sum(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([6, 15]).astype(np.int32))


def test_prod():
common_test_reduce(opr=F.prod, ref_opr=np.prod)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.prod(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([4, 10, 18]).astype(np.int32))


def test_mean():
common_test_reduce(opr=F.mean, ref_opr=np.mean)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.mean(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([2.5, 3.5, 4.5]).astype(np.float32))


def test_var():
common_test_reduce(opr=F.var, ref_opr=np.var)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.var(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([2.25, 2.25, 2.25]).astype(np.float32))


def test_std():
common_test_reduce(opr=F.std, ref_opr=np.std)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.std(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))


def test_min():
common_test_reduce(opr=F.min, ref_opr=np.min)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.min(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([1, 4]).astype(np.int32))


def test_max():
common_test_reduce(opr=F.max, ref_opr=np.max)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.max(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([3, 6]).astype(np.int32))


def test_argmin():
common_test_reduce(opr=F.argmin, ref_opr=np.argmin)

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.argmin(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([0, 0]).astype(np.int32))


def test_argmax():
common_test_reduce(opr=F.argmax, ref_opr=np.argmax)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.argmax(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1, 1, 1]).astype(np.int32))


def test_norm():
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.norm(x, axis=-1)
np.testing.assert_equal(
y.numpy().round(decimals=3), np.array([3.742, 8.775]).astype(np.float32)
)


def test_sqrt():
@@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic):
fn_ = fn
data = np.random.random(shape).astype(np.float32)
for _ in range(3):
outs = fn_(tensor(data))
outs = fn_(Tensor(data))
ref_outs = (np.sort(data), np.argsort(data))
assert len(ref_outs) == len(outs)
for i in range(len(outs)):
@@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic):


def test_normalize():
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.normalize(x, axis=-1)
np.testing.assert_equal(
y.numpy().round(decimals=1),
np.array([[0.3, 0.5, 0.8], [0.5, 0.6, 0.7]]).astype(np.float32),
)

cases = [
{"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
@@ -177,11 +230,11 @@ def test_sum_neg_axis():
shape = (2, 3)
data = np.random.random(shape).astype(np.float32)
for axis in (-1, -2, (-2, 1), (-1, 0)):
get = F.sum(tensor(data), axis=axis)
get = F.sum(Tensor(data), axis=axis)
ref = np.sum(data, axis=axis)
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
with pytest.raises(AssertionError):
F.sum(tensor(data), axis=(-1, 1))
F.sum(Tensor(data), axis=(-1, 1))


def test_builtin_reduce():
@@ -204,18 +257,18 @@ def test_non_finite():
data = []
for i in range(2):
data.append(np.random.random(shape).astype(np.float32))
tensorList = [tensor(x) for x in data]
tensorList = [Tensor(x) for x in data]
rst = F.math._check_non_finite(tensorList, 0.7)
np.testing.assert_equal(rst.numpy(), [0])
for i in range(len(tensorList)):
np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)

data[1][0][0][0][0] = float("inf")
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
np.testing.assert_equal(rst.numpy(), [1])

data[1][0][0][0][0] = float("nan")
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
np.testing.assert_equal(rst.numpy(), [1])


@@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only):
return np.sort(x)

res = F.topk(
tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
Tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
)

values, indices = res
@@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace):
if is_trace:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(tensor(input, dtype=dtype), axis=axis).numpy()
out = fn(Tensor(input, dtype=dtype), axis=axis).numpy()
out_ref = ref_fn(input.astype(dtype), axis=axis)
np.testing.assert_equal(out, out_ref)



imperative/python/test/unit/functional/test_tensor.py (+120, -15)

@@ -7,7 +7,7 @@ import pytest
from utils import get_var_value, make_tensor, opr_test

import megengine.functional as F
from megengine import tensor
from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
@@ -30,7 +30,7 @@ def test_eye():
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
F.eye(Tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)

@@ -60,7 +60,21 @@ def test_full():
values = [True, 4, 5.0]
for value in values:
np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
assert F.full(shape, value).dtype == tensor(value).dtype
assert F.full(shape, value).dtype == Tensor(value).dtype


@pytest.mark.parametrize("is_varnode", [True, False])
def test_cumsum(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
y = F.cumsum(x, -1)
np.testing.assert_equal(
y.numpy(), np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32)
)


@pytest.mark.parametrize("is_varnode", [True, False])
@@ -83,6 +97,14 @@ def test_concat(is_varnode):
cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)

x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
y = F.concat([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(),
np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32),
)


@pytest.mark.parametrize("is_varnode", [True, False])
def test_condtake(is_varnode):
@@ -139,6 +161,20 @@ def test_stack(is_varnode):
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
)

x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
y = F.stack([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
)


@pytest.mark.parametrize("is_varnode", [True, False])
def test_split_basic(is_varnode):
@@ -183,6 +219,12 @@ def test_split_basic(is_varnode):

@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
x = Tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3, axis=-1)
z = F.split(x, [6, 17], axis=-1)
assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]"
assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]"

inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)

@@ -208,12 +250,43 @@ def test_split(symbolic):
fn = trace(symbolic=symbolic)(func)
for i in range(3 if symbolic is not None else 1):
ref_out = ref(*case)
out = fn(tensor(case[0]), case[1], case[2])
out = fn(Tensor(case[0]), case[1], case[2])
assert len(ref_out) == len(out)
for idx in range(len(ref_out)):
np.testing.assert_equal(ref_out[idx], out[idx].numpy())


def test_gather():
x = Tensor([[1, 2], [3, 4], [5, 6],])
index = Tensor([[0, 1], [1, 0], [1, 1]])
y = F.gather(x, 1, index)
np.testing.assert_equal(
y.numpy(), np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32)
)


def test_scatter():
x = Tensor(np.zeros(shape=(3, 5), dtype=np.float32))
source = Tensor(
[
[0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
[0.7723, 0.0718, 0.5939, 0.357, 0.4576],
]
)
index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])
y = F.scatter(x, -2, index, source)
np.testing.assert_equal(
y.numpy().round(decimals=4),
np.array(
[
[0.9935, 0.0718, 0.2256, 0.0, 0.0],
[0.0, 0.0, 0.5939, 0.357, 0.4396],
[0.7723, 0.9465, 0.0, 0.8926, 0.4576],
]
).astype(np.float32),
)


@pytest.mark.parametrize("is_varnode", [True, False])
def test_swapaxes(is_varnode):
if is_varnode:
@@ -221,7 +294,7 @@ def test_swapaxes(is_varnode):
else:
network = None

x = tensor(np.array([[1, 2, 3]], dtype=np.int32))
x = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
y = F.swapaxes(x, 0, 1)
np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))

@@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode):
def test_reshape_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32))
data1 = Tensor(np.random.random(input1_shape).astype(np.float32))

input2_shape = (10, 0)
output2_shape = (0,)
data2 = tensor(np.random.random(input2_shape).astype(np.float32))
data2 = Tensor(np.random.random(input2_shape).astype(np.float32))

input3_shape = (10, 0, 10)
output3_shape = (0, 1, 2, 3)
data3 = tensor(np.random.random(input3_shape).astype(np.float32))
data3 = Tensor(np.random.random(input3_shape).astype(np.float32))

def comp(out, target_shp):
assert out._tuple_shape == target_shp
@@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode):

def check_shape(output, target):
source = output.shape
if isinstance(source, tensor):
if isinstance(source, Tensor):
source = source.numpy()
np.testing.assert_equal(source, target.shape)

@@ -366,6 +439,10 @@ def test_squeeze(is_varnode):
else:
network = None

x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
y = F.squeeze(x, -1)
np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32))

x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = make_tensor(x, network)

@@ -385,6 +462,12 @@ def test_expand_dims(is_varnode):
else:
network = None

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.expand_dims(x, -1)
np.testing.assert_equal(
y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32)
)

x = np.arange(6, dtype="float32").reshape(2, 3)
xx = make_tensor(x, network)

@@ -533,6 +616,22 @@ def test_flatten(is_varnode):
else:
network = None

inp_shape = (2, 2, 3, 3)
x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),)
y = F.flatten(x, -2, -1)
np.testing.assert_equal(
y.numpy(),
np.array(
[
[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]],
[
[18, 19, 20, 21, 22, 23, 24, 25, 26],
[27, 28, 29, 30, 31, 32, 33, 34, 35],
],
]
).astype(np.int32),
)

data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32)
@@ -616,15 +715,15 @@ def test_broadcast(is_varnode):
def test_broadcast_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32))
data1 = Tensor(np.random.random(input1_shape).astype(np.float32))

input2_shape = (10, 0)
output2_shape = (10, 10, 0)
data2 = tensor(np.random.random(input2_shape).astype(np.float32))
data2 = Tensor(np.random.random(input2_shape).astype(np.float32))

input3_shape = (0, 0, 1, 10)
output3_shape = (10, 0, 0, 10, 10)
data3 = tensor(np.random.random(input3_shape).astype(np.float32))
data3 = Tensor(np.random.random(input3_shape).astype(np.float32))

def comp(out, target_shp):
assert out._tuple_shape == target_shp
@@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode):


def test_device():
x = tensor([1, 2, 3], dtype="float32")
x = Tensor([1, 2, 3], dtype="float32")

y1 = F.eye(x.shape, dtype="float32")
y2 = F.eye(x.shape, dtype="float32", device=None)
@@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode):
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_copy_empty(shape, device_src, device_dst, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src)

def func(inp):
return F.copy(inp, device_dst)
@@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode):
else:
network = None

x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32)
y = F.roll(x, 1, -1)
np.testing.assert_equal(
y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32)
)

inp = np.random.randn(*shape).astype("float32")

def func(inp):
@@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode):
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"))
inp = Tensor(np.random.randn(*shape).astype("float32"))

def func(inp):
return F.roll(inp, shifts, axis)


imperative/src/impl/ops/indexing.cpp (+120, -11)

@@ -1,8 +1,10 @@
#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

#include "megbrain/opr/indexing.h"
#include "megdnn/oprs/general.h"

namespace mgb {
namespace imperative {
@@ -12,10 +14,8 @@ namespace indexing_one_hot {

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<IndexingOneHot>();

auto&& op = def.cast_final_safe<IndexingOneHot>();
mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");

auto comp_node = input_descs[0].comp_node;
TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;

@@ -28,10 +28,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
mgb_assert(src.is_contiguous(), "src should be contiguous");
mgb_assert(
op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis);
mgb_assert(
-static_cast<int>(src.ndim) <= op.axis &&
op.axis < static_cast<int>(src.ndim),
"axis %d not exists in src", op.axis);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(src.ndim);
}
TensorLayout dst = src;
dst.shape[op.axis] = 1;
dst.shape[real_axis] = 1;
dst.init_contiguous_stride();

if (!index.ndim) {
@@ -40,24 +45,128 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(

mgb_assert(index.is_contiguous(), "index should be all contiguous");
mgb_assert(
index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src");
index.eq_shape(src.remove_axis(real_axis)),
"index shape doesn't match src");
return {{{dst, comp_node}}, true};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
auto&& op = def.cast_final_safe<IndexingOneHot>();
mgb_assert(inputs.size() == 2);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(op.ndim);
}
OperatorNodeConfig config{op.make_name()};
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<IndexingOneHot>();
auto&& inp = inputs[0];
auto&& index = inputs[1];
TensorLayout layout = inp->layout();
TensorLayout index_layout = index->layout();
DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node());
auto&& indexing_one_hot_param = dnn_op.op->param();
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(layout.ndim);
}
mgb_assert(
0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
"Dimension out of range (expected to be in range of [%d, %d], but got %d)",
0, static_cast<int>(layout.ndim) - 1, op.axis);
indexing_one_hot_param = real_axis;
TensorLayout tlayout;
dnn_op.op->deduce_layout(layout, index_layout, tlayout);
TensorPtr out = Tensor::make(tlayout, inp->comp_node());
megdnn::TensorND in = inp->dnn_tensor();
megdnn::TensorND ind = index->dnn_tensor();
TensorLayout m_layout(
{dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
dtype::Byte());
auto dnn_workspace = dnn_op.create_workspace(m_layout);
dnn_op.op->exec(in, ind, out->dnn_tensor(), dnn_workspace);
return {out};
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();

} // namespace indexing_one_hot

namespace indexing_set_one_hot {

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");
auto comp_node = input_descs[0].comp_node;
TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;

mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");

if (!src.ndim) {
return {{{{{}, src.dtype}, comp_node}}, false};
}
mgb_assert(src.is_contiguous(), "src should be contiguous");
return {{input_descs[0]}, true};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(op.ndim);
}
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(
inputs[0], inputs[1], inputs[2], real_axis, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<IndexingSetOneHot>();
auto&& inp = inputs[0];
auto&& index = inputs[1];
auto&& sub = inputs[2];
TensorLayout layout = inp->layout();
TensorLayout index_layout = index->layout();
TensorLayout tlayout = sub->layout();
mgb_assert(layout.is_contiguous());
DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node());
auto&& indexing_one_hot_param = dnn_op.op->param();
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(layout.ndim);
}
indexing_one_hot_param = real_axis;
TensorPtr out = Tensor::make(layout, inp->comp_node());
out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());
megdnn::TensorND in = inp->dnn_tensor();
megdnn::TensorND ind = index->dnn_tensor();
megdnn::TensorND su = sub->dnn_tensor();
TensorLayout m_layout(
{dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
dtype::Byte());
auto dnn_workspace = dnn_op.create_workspace(m_layout);
dnn_op.op->exec(out->dnn_tensor(), ind, su, dnn_workspace);
return {out};
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace indexing_set_one_hot

} // anonymous namespace
} // namespace imperative
} // namespace mgb


imperative/src/impl/ops/specializations.cpp (+0, -15)

@@ -373,21 +373,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba
} // namespace

namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(
inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace indexing_set_one_hot
} // namespace

namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const TypeCvt&>(def);


src/core/include/megbrain/ir/ops.td (+10, -2)

@@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;

def Resize: MgbHashableOp<"Resize", [ResizeParam]>;

def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}

def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}

def Copy: MgbHashableOp<"Copy"> {
let extraArguments = (ins


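A brief note on the new ndim attribute: when IndexingOneHot / IndexingSetOneHot are applied on var nodes (graph or trace mode) the concrete input layout may not be known, so the Python layer now records src.ndim on the op and the var-node handlers use it to fold a negative axis into its non-negative form (see the real_axis logic in indexing.cpp above); the eager path derives the same thing from the input layout. The normalization itself is the usual NumPy-style rule, shown here as an illustrative helper (hypothetical, not a MegEngine API):

def canonicalize_axis(axis: int, ndim: int) -> int:
    # A negative axis counts from the last dimension.
    real_axis = axis + ndim if axis < 0 else axis
    assert 0 <= real_axis < ndim, f"axis {axis} out of range for ndim {ndim}"
    return real_axis

assert canonicalize_axis(-1, 4) == 3
assert canonicalize_axis(2, 4) == 2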