GitOrigin-RevId: f61e01270b
release-1.10
@@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase {
     DEF_OPR_PARAM(Axis);
 
 protected:
-    void deduce_layout_fwd(
+    MGE_WIN_DECLSPEC_FUC void deduce_layout_fwd(
             const TensorLayout& src, const TensorLayout& index, TensorLayout& dst);
     void check_layout_fwd(
             const TensorLayout& src, const TensorLayout& index,
@@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
     )
     ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
 
-    op = builtin.IndexingSetOneHot(axis=inp.ndim)
+    op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim)
     (result,) = apply(op, zeros_tensor, inp, ones_tensor)
     return result
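A NumPy sketch (not the MegDNN kernel) of the IndexingSetOneHot semantics that one_hot relies on here: ones_tensor is written into zeros_tensor at the positions given by inp along the last axis.

    import numpy as np

    def indexing_set_one_hot(src, index, sub, axis):
        # copy src, then overwrite the entries selected by index along `axis` with sub
        out = src.copy()
        np.put_along_axis(out, np.expand_dims(index, axis), sub, axis)
        return out

    # one_hot([0, 2], num_classes=3) is assembled the same way
    zeros_tensor = np.zeros((2, 3), dtype=np.int32)
    ones_tensor = np.ones((2, 1), dtype=np.int32)
    print(indexing_set_one_hot(zeros_tensor, np.array([0, 2]), ones_tensor, axis=1))
    # [[1 0 0]
    #  [0 0 1]]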
@@ -1609,7 +1609,7 @@ def indexing_one_hot(
         array([1.], dtype=float32)
     """
     assert isinstance(src, Tensor), "src must be of Tensor type"
-    op = builtin.IndexingOneHot(axis=axis)
+    op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim)
     index = convert_single_value(index, dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
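With ndim now forwarded to the op, the C++ side can normalize a negative axis; a usage sketch mirroring the new test_indexing added in this patch (assumes the patch is applied):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    src = mge.Tensor(np.array([[1.0, 2.0]], dtype=np.float32))
    index = mge.Tensor([0])
    out = F.indexing_one_hot(src, index, -1)  # axis -1 resolves to the last axis
    print(out.numpy())  # [1.]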
@@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0):
 def _get_idx(index, axis):
     index_dims = len(index.shape)
     idx = []
+    if axis < 0:
+        axis += index_dims
     for i in range(index_dims):
         if i != axis:
             shape = [1] * index_dims
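A minimal standalone sketch of the axis normalization added above (plain Python, no MegEngine dependency):

    def normalize_axis(axis: int, ndim: int) -> int:
        # a negative axis counts from the end, e.g. -1 maps to ndim - 1
        return axis + ndim if axis < 0 else axis

    assert normalize_axis(-1, 4) == 3
    assert normalize_axis(-2, 2) == 0
    assert normalize_axis(1, 4) == 1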
@@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
             "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
         )
 
-    if axis < 0 or axis >= input_dims:
-        raise ValueError(
-            "Index axis {} is output of bounds, should in range [0 {})".format(
-                axis, input_dims
-            )
-        )
-
-    for i in range(input_dims):
-        if i != axis and input_shape[i] != index_shape[i]:
-            raise ValueError(
-                "The input {} and index {} must have the same size apart from axis {}".format(
-                    input_shape, index_shape, axis
-                )
-            )
-
     idx = _get_idx(index, axis)
     return inp[idx].reshape(index.shape)  # pylint: disable=no-member
@@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
         >>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32))
         >>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
         >>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]])
-        >>> oup = F.scatter(inp, 0, index,source)
+        >>> oup = F.scatter(inp, 0, index, source)
        >>> oup.numpy()
         array([[0.9935, 0.0718, 0.2256, 0. , 0. ],
                [0. , 0. , 0.5939, 0.357 , 0.4396],
@@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
     if input_dims != index_dims or input_dims != source_dims:
         raise ValueError("The input, source and index tensor must have same dimensions")
 
-    if axis < 0 or axis >= input_dims:
-        raise ValueError(
-            "Index axis {} is output of bounds, should in range [0 {})".format(
-                axis, input_dims
-            )
-        )
-
     for i in range(source_dims):
         if source_shape[i] > input_shape[i]:
             raise ValueError(
@@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
         >>> out.numpy().shape
         (2, 2, 9)
     """
+    if start_axis < 0:
+        start_axis += len(inp.shape)
     target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
     if end_axis != -1:
         target_shape += (*inp.shape[end_axis + 1 :],)
@@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int):
          [ 4 9 15]], dtype=int32, device=xpux:0)
     """
     assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
-    assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
     op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
     return apply(op, inp)[0]
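With the bound assert removed, a negative axis is handed straight to the Cumsum op; a usage sketch matching the new test_cumsum in test_tensor.py further down (assumes the patch is applied):

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
    print(F.cumsum(x, -1).numpy())  # rows accumulate left to right: [[1, 3, 6], [4, 9, 15]]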
@@ -490,6 +490,84 @@ std::optional<ValueRefList> pixelShuffle_grad_rule(
     return imperative::apply(op, inputs);
 }
 
+std::optional<ValueRefList> indexing_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& indexing = op.cast_final_safe<IndexingOneHot>();
+    mgb_assert(inputs.size() == 2);
+    bool flag = inputs_require_grad[0];
+    auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim);
+    SmallVector<ValueRef> inputs2;
+    if (flag) {
+        inputs2.push_back(get_shape(inputs[0]));
+        for (size_t i = 1; i < inputs.size(); ++i) {
+            inputs2.push_back(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inputs = std::move(inputs2),
+                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(1);
+        if (grad && inputs[0]) {
+            ValueRefList args_(inputs.size() + 1);
+            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
+            args_[0] = zeros;
+            args_[1] = inputs[1];
+            args_[2] = grads[0];
+            ret[0] = imperative::apply(*grad_op_, args_)[0];
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(op, inputs);
+}
+
+std::optional<ValueRefList> indexing_set_one_hot_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& indexingSetOneHot = op.cast_final_safe<IndexingSetOneHot>();
+    mgb_assert(inputs.size() == 3);
+    SmallVector<ValueRef> inputs2;
+    inputs2.push_back(get_shape(inputs[0]));
+    inputs2.push_back(inputs[1]);
+    inputs2.push_back(get_shape(inputs[2]));
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inputs = std::move(inputs2),
+                    &indexingSetOneHot](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(3);
+        if (!grad) {
+            return ret;
+        }
+        if (inputs[0]) {
+            auto&& grad_op = IndexingSetOneHot::make(
+                    indexingSetOneHot.axis, indexingSetOneHot.ndim);
+            ValueRefList args_(inputs.size());
+            auto&& zeros = make_empty_tensor(grad.device(), inputs[2], grad.dtype());
+            args_[0] = grads[0];
+            args_[1] = inputs[1];
+            args_[2] = zeros;
+            ret[0] = imperative::apply(*grad_op, args_)[0];
+        }
+        if (inputs[2]) {
+            auto&& grad_op = IndexingOneHot::make(
+                    indexingSetOneHot.axis, indexingSetOneHot.ndim);
+            ValueRefList args_(inputs.size() - 1);
+            args_[0] = grads[0];
+            args_[1] = inputs[1];
+            ret[2] = imperative::apply(*grad_op, args_)[0];
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(op, inputs);
+}
+
 std::optional<ValueRefList> fastpathcopy_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
@@ -522,6 +600,10 @@ struct Init {
         CustomBackward::register_grad_rule(
                 RemoveAxis::typeinfo(), removeAxis_grad_rule);
         CustomBackward::register_grad_rule(
+                IndexingOneHot::typeinfo(), indexing_grad_rule);
+        CustomBackward::register_grad_rule(
+                IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule);
+        CustomBackward::register_grad_rule(
                 FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
         CustomBackward::register_grad_rule(
                 PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
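Conceptually the new IndexingOneHot backward scatters the incoming gradient into a zero tensor shaped like src (done above via IndexingSetOneHot); a NumPy sketch with a hypothetical helper name:

    import numpy as np

    def indexing_one_hot_grad(src_shape, index, grad_out, axis):
        # dL/dsrc is zero everywhere except at the gathered positions
        grad_src = np.zeros(src_shape, dtype=grad_out.dtype)
        np.put_along_axis(grad_src, np.expand_dims(index, axis), grad_out, axis)
        return grad_src

    g = indexing_one_hot_grad((1, 2), np.array([0]), np.array([[1.0]]), axis=1)
    print(g)  # [[1. 0.]], matching the expectation in the new test_indexing below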
@@ -8,11 +8,15 @@ import megengine as mge
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
+from megengine import Tensor
+from megengine.core import _imperative_rt
 from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
 from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
 from megengine.core.autodiff.grad import Grad
+from megengine.core.ops import builtin
 from megengine.core.ops.builtin import Elemwise, Identity
 from megengine.functional.distributed import remote_recv, remote_send
+from megengine.functional.tensor import ones, zeros
 
 
 def _elwise(mode):
@@ -553,3 +557,46 @@ def test_matmul():
                 if ydim == 1 and transposeB == True:
                     continue
                 test_one(xdim, ydim, transposeA, transposeB)
+
+
+def test_indexing():
+    x = np.array([[1.0, 2.0]]).astype("float32")
+    x = mge.Tensor(x)
+    index = mge.Tensor([0])
+
+    with Grad() as grad:
+        grad.wrt(x, callback=save_to(x))
+
+        def f(x):
+            return F.indexing_one_hot(x, index, -1)
+
+        y = f(x)
+        grad(y, F.ones_like(y))
+
+    np.testing.assert_equal(np.array([[1, 0]], dtype=np.float32), x.grad.numpy())
+
+
+def test_indexing_set_one_hot():
+    x = mge.tensor(np.arange(1, 4, dtype=np.int32))
+
+    with Grad() as grad:
+        zeros_tensor = zeros((3, 4), dtype=x.dtype, device=x.device)
+        ones_tensor = ones((3, 1), dtype=x.dtype, device=x.device)
+        grad.wrt(zeros_tensor, callback=save_to(zeros_tensor))
+        grad.wrt(ones_tensor, callback=save_to(ones_tensor))
+
+        def f(x):
+            op = builtin.IndexingSetOneHot(axis=x.ndim, ndim=x.ndim)
+            (result,) = apply(op, zeros_tensor, x, ones_tensor)
+            return result
+
+        y = f(x)
+        grad(y, F.ones_like(y))
+
+    np.testing.assert_equal(
+        np.array([[1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]], dtype=np.int32),
+        zeros_tensor.grad.numpy(),
+    )
+    np.testing.assert_equal(
+        np.array([[1], [1], [1]], dtype=np.int32), ones_tensor.grad.numpy(),
+    )
@@ -6,9 +6,7 @@ import pytest
 import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optimizer
-from megengine import Parameter
-from megengine import Tensor as tensor
-from megengine import tensor
+from megengine import Parameter, Tensor, tensor
 from megengine.autodiff import Function
 from megengine.module import Module
@@ -3,15 +3,15 @@ import numpy as np
 import pytest
 
 import megengine.functional as F
-from megengine import tensor
+import megengine.tensor as Tensor
 
 
 def test_cross_entropy_with_logits():
-    data = tensor([[0, 50], [0, -150]]).astype(np.float32)
-    label = tensor([1, 0]).astype(np.int32)
+    data = Tensor([[0, 50], [0, -150]]).astype(np.float32)
+    label = Tensor([1, 0]).astype(np.int32)
     loss = F.nn.cross_entropy(data, label)
     np.testing.assert_allclose(loss.numpy(), 0.0)
-    label = tensor([0, 1]).astype(np.int32)
+    label = Tensor([0, 1]).astype(np.int32)
     loss = F.nn.cross_entropy(data, label)
     np.testing.assert_allclose(loss.numpy(), 100)
@@ -35,19 +35,24 @@ def test_cross_entropy():
         x[i, y[i]] += np.random.rand() * 2
     x = softmax(x)
     l_ref = ref(x, y)
-    l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
+    l = F.nn.cross_entropy(Tensor(x, "float32"), Tensor(y, "int32"), with_logits=False)
     np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)
+    l1 = F.nn.cross_entropy(
+        Tensor(x, "float32"), Tensor(y, "int32"), axis=-1, with_logits=False
+    )
+    np.testing.assert_allclose(l1.numpy(), l_ref, 1e-6, 1e-6)
 
 
 def test_cross_entropy_reduction():
     logits = np.random.randn(16, 10)
     label = np.random.randint(10, size=[16])
-    logits = tensor(logits, dtype="float32")
-    label = tensor(label, dtype="int32")
+    logits = Tensor(logits, dtype="float32")
+    label = Tensor(label, dtype="int32")
 
     perm = np.random.permutation(16)
-    logits_perm = tensor(logits[perm], dtype="float32")
-    label_perm = tensor(label[perm], dtype="int32")
+    logits_perm = Tensor(logits[perm], dtype="float32")
+    label_perm = Tensor(label[perm], dtype="int32")
 
     loss = F.nn.cross_entropy(logits, label, reduction="none")
     loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
@@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank):
 def test_ctc_loss():
     def test_func(T, C, N):
         input = np.random.randn(T, N, C)
-        input = F.softmax(tensor(input), axis=-1).numpy()
+        input = F.softmax(Tensor(input), axis=-1).numpy()
         input_lengths = np.ones(N, dtype=np.int32) * T
         target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
         target = np.random.randint(
            low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
         )
-        input_mge = tensor(input)
-        input_lengths_mge = tensor(input_lengths)
+        input_mge = Tensor(input)
+        input_lengths_mge = Tensor(input_lengths)
 
-        target_mge = tensor(target)
-        target_lengths_mge = tensor(target_lengths)
+        target_mge = Tensor(target)
+        target_lengths_mge = Tensor(target_lengths)
 
         blank = np.random.randint(C)
         for method in ["mean", "sum", "none"]:
@@ -6,7 +6,7 @@ import pytest
 from utils import opr_test
 
 import megengine.functional as F
-from megengine import jit, tensor
+from megengine import Tensor, jit, tensor
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core.ops import builtin
@@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr):
 def test_sum():
     common_test_reduce(opr=F.sum, ref_opr=np.sum)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.sum(x, axis=-1)
+    np.testing.assert_equal(y.numpy(), np.array([6, 15]).astype(np.int32))
 
 
 def test_prod():
     common_test_reduce(opr=F.prod, ref_opr=np.prod)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.prod(x, axis=-2)
+    np.testing.assert_equal(y.numpy(), np.array([4, 10, 18]).astype(np.int32))
 
 
 def test_mean():
     common_test_reduce(opr=F.mean, ref_opr=np.mean)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.mean(x, axis=-2)
+    np.testing.assert_equal(y.numpy(), np.array([2.5, 3.5, 4.5]).astype(np.float32))
 
 
 def test_var():
     common_test_reduce(opr=F.var, ref_opr=np.var)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.var(x, axis=-2)
+    np.testing.assert_equal(y.numpy(), np.array([2.25, 2.25, 2.25]).astype(np.float32))
 
 
 def test_std():
     common_test_reduce(opr=F.std, ref_opr=np.std)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.std(x, axis=-2)
+    np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.std(x, axis=-2)
+    np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
 
 
 def test_min():
     common_test_reduce(opr=F.min, ref_opr=np.min)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.min(x, axis=-1)
+    np.testing.assert_equal(y.numpy(), np.array([1, 4]).astype(np.int32))
 
 
 def test_max():
     common_test_reduce(opr=F.max, ref_opr=np.max)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.max(x, axis=-1)
+    np.testing.assert_equal(y.numpy(), np.array([3, 6]).astype(np.int32))
 
 
 def test_argmin():
     common_test_reduce(opr=F.argmin, ref_opr=np.argmin)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.argmin(x, axis=-1)
+    np.testing.assert_equal(y.numpy(), np.array([0, 0]).astype(np.int32))
 
 
 def test_argmax():
     common_test_reduce(opr=F.argmax, ref_opr=np.argmax)
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.argmax(x, axis=-2)
+    np.testing.assert_equal(y.numpy(), np.array([1, 1, 1]).astype(np.int32))
+
+
+def test_norm():
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.norm(x, axis=-1)
+    np.testing.assert_equal(
+        y.numpy().round(decimals=3), np.array([3.742, 8.775]).astype(np.float32)
+    )
 
 
 def test_sqrt():
@@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic):
         fn_ = fn
     data = np.random.random(shape).astype(np.float32)
     for _ in range(3):
-        outs = fn_(tensor(data))
+        outs = fn_(Tensor(data))
         ref_outs = (np.sort(data), np.argsort(data))
         assert len(ref_outs) == len(outs)
         for i in range(len(outs)):
@@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic):
 
 
 def test_normalize():
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.normalize(x, axis=-1)
+    np.testing.assert_equal(
+        y.numpy().round(decimals=1),
+        np.array([[0.3, 0.5, 0.8], [0.5, 0.6, 0.7]]).astype(np.float32),
+    )
     cases = [
         {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
@@ -177,11 +230,11 @@ def test_sum_neg_axis():
     shape = (2, 3)
     data = np.random.random(shape).astype(np.float32)
     for axis in (-1, -2, (-2, 1), (-1, 0)):
-        get = F.sum(tensor(data), axis=axis)
+        get = F.sum(Tensor(data), axis=axis)
         ref = np.sum(data, axis=axis)
         np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
     with pytest.raises(AssertionError):
-        F.sum(tensor(data), axis=(-1, 1))
+        F.sum(Tensor(data), axis=(-1, 1))
 
 
 def test_builtin_reduce():
@@ -204,18 +257,18 @@ def test_non_finite():
     data = []
     for i in range(2):
         data.append(np.random.random(shape).astype(np.float32))
-    tensorList = [tensor(x) for x in data]
+    tensorList = [Tensor(x) for x in data]
     rst = F.math._check_non_finite(tensorList, 0.7)
     np.testing.assert_equal(rst.numpy(), [0])
     for i in range(len(tensorList)):
         np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)
 
     data[1][0][0][0][0] = float("inf")
-    rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
+    rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
     np.testing.assert_equal(rst.numpy(), [1])
 
     data[1][0][0][0][0] = float("nan")
-    rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
+    rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
     np.testing.assert_equal(rst.numpy(), [1])
@@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only):
             return np.sort(x)
 
     res = F.topk(
-        tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
+        Tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
     )
 
     values, indices = res
@@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace):
     if is_trace:
         fn = jit.trace(symbolic=symbolic)(fn)
     for i in range(3):
-        out = fn(tensor(input, dtype=dtype), axis=axis).numpy()
+        out = fn(Tensor(input, dtype=dtype), axis=axis).numpy()
         out_ref = ref_fn(input.astype(dtype), axis=axis)
         np.testing.assert_equal(out, out_ref)
@@ -7,7 +7,7 @@ import pytest
 from utils import get_var_value, make_tensor, opr_test
 
 import megengine.functional as F
-from megengine import tensor
+from megengine import Tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.utils import astensor1d
@@ -30,7 +30,7 @@ def test_eye():
             np.eye(*case["input"]).astype(dtype),
         )
         np.testing.assert_allclose(
-            F.eye(tensor(case["input"]), dtype=dtype).numpy(),
+            F.eye(Tensor(case["input"]), dtype=dtype).numpy(),
             np.eye(*case["input"]).astype(dtype),
         )
@@ -60,7 +60,21 @@ def test_full():
     values = [True, 4, 5.0]
     for value in values:
         np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
-        assert F.full(shape, value).dtype == tensor(value).dtype
+        assert F.full(shape, value).dtype == Tensor(value).dtype
+
+
+@pytest.mark.parametrize("is_varnode", [True, False])
+def test_cumsum(is_varnode):
+    if is_varnode:
+        network = Network()
+    else:
+        network = None
+
+    x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
+    y = F.cumsum(x, -1)
+    np.testing.assert_equal(
+        y.numpy(), np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32)
+    )
 
 
 @pytest.mark.parametrize("is_varnode", [True, False])
@@ -83,6 +97,14 @@ def test_concat(is_varnode):
     cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
     opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
 
+    x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
+    x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
+    y = F.concat([x1, x2], axis=-1)
+    np.testing.assert_equal(
+        y.numpy(),
+        np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32),
+    )
+
 
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_condtake(is_varnode):
@@ -139,6 +161,20 @@ def test_stack(is_varnode):
             cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
         )
 
+    x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
+    x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
+    y = F.stack([x1, x2], axis=-1)
+    np.testing.assert_equal(
+        y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
+    )
+
+    x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
+    x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
+    y = F.stack([x1, x2], axis=-1)
+    np.testing.assert_equal(
+        y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
+    )
+
 
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_split_basic(is_varnode):
@@ -183,6 +219,12 @@ def test_split_basic(is_varnode):
 
 @pytest.mark.parametrize("symbolic", [None, False, True])
 def test_split(symbolic):
+    x = Tensor(np.random.random((10, 20)), dtype=np.float32)
+    y = F.split(x, 3, axis=-1)
+    z = F.split(x, [6, 17], axis=-1)
+    assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]"
+    assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]"
+
     inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
     inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
@@ -208,12 +250,43 @@ def test_split(symbolic):
             fn = trace(symbolic=symbolic)(func)
         for i in range(3 if symbolic is not None else 1):
             ref_out = ref(*case)
-            out = fn(tensor(case[0]), case[1], case[2])
+            out = fn(Tensor(case[0]), case[1], case[2])
             assert len(ref_out) == len(out)
             for idx in range(len(ref_out)):
                 np.testing.assert_equal(ref_out[idx], out[idx].numpy())
 
 
+def test_gather():
+    x = Tensor([[1, 2], [3, 4], [5, 6],])
+    index = Tensor([[0, 1], [1, 0], [1, 1]])
+    y = F.gather(x, 1, index)
+    np.testing.assert_equal(
+        y.numpy(), np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32)
+    )
+
+
+def test_scatter():
+    x = Tensor(np.zeros(shape=(3, 5), dtype=np.float32))
+    source = Tensor(
+        [
+            [0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
+            [0.7723, 0.0718, 0.5939, 0.357, 0.4576],
+        ]
+    )
+    index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])
+    y = F.scatter(x, -2, index, source)
+    np.testing.assert_equal(
+        y.numpy().round(decimals=4),
+        np.array(
+            [
+                [0.9935, 0.0718, 0.2256, 0.0, 0.0],
+                [0.0, 0.0, 0.5939, 0.357, 0.4396],
+                [0.7723, 0.9465, 0.0, 0.8926, 0.4576],
+            ]
+        ).astype(np.float32),
+    )
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_swapaxes(is_varnode):
     if is_varnode:
@@ -221,7 +294,7 @@ def test_swapaxes(is_varnode):
     else:
         network = None
 
-    x = tensor(np.array([[1, 2, 3]], dtype=np.int32))
+    x = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
     y = F.swapaxes(x, 0, 1)
     np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
@@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode):
 def test_reshape_on_empty_tensor(is_trace):
     input1_shape = (100, 0, 1)
     output1_shape = (100, 0, 10)
-    data1 = tensor(np.random.random(input1_shape).astype(np.float32))
+    data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
 
     input2_shape = (10, 0)
     output2_shape = (0,)
-    data2 = tensor(np.random.random(input2_shape).astype(np.float32))
+    data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
 
     input3_shape = (10, 0, 10)
     output3_shape = (0, 1, 2, 3)
-    data3 = tensor(np.random.random(input3_shape).astype(np.float32))
+    data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
 
     def comp(out, target_shp):
         assert out._tuple_shape == target_shp
@@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode):
     def check_shape(output, target):
         source = output.shape
-        if isinstance(source, tensor):
+        if isinstance(source, Tensor):
             source = source.numpy()
         np.testing.assert_equal(source, target.shape)
@@ -366,6 +439,10 @@ def test_squeeze(is_varnode):
     else:
         network = None
 
+    x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
+    y = F.squeeze(x, -1)
+    np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32))
+
     x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
     xx = make_tensor(x, network)
@@ -385,6 +462,12 @@ def test_expand_dims(is_varnode):
     else:
         network = None
 
+    x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
+    y = F.expand_dims(x, -1)
+    np.testing.assert_equal(
+        y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32)
+    )
+
     x = np.arange(6, dtype="float32").reshape(2, 3)
     xx = make_tensor(x, network)
@@ -533,6 +616,22 @@ def test_flatten(is_varnode):
     else:
         network = None
 
+    inp_shape = (2, 2, 3, 3)
+    x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),)
+    y = F.flatten(x, -2, -1)
+    np.testing.assert_equal(
+        y.numpy(),
+        np.array(
+            [
+                [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]],
+                [
+                    [18, 19, 20, 21, 22, 23, 24, 25, 26],
+                    [27, 28, 29, 30, 31, 32, 33, 34, 35],
+                ],
+            ]
+        ).astype(np.int32),
+    )
+
     data0_shape = (2, 3, 4, 5)
     data1_shape = (4, 5, 6, 7)
     data0 = np.random.random(data0_shape).astype(np.float32)
@@ -616,15 +715,15 @@ def test_broadcast(is_varnode):
 def test_broadcast_on_empty_tensor(is_trace):
     input1_shape = (100, 0, 1)
     output1_shape = (100, 0, 10)
-    data1 = tensor(np.random.random(input1_shape).astype(np.float32))
+    data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
 
     input2_shape = (10, 0)
     output2_shape = (10, 10, 0)
-    data2 = tensor(np.random.random(input2_shape).astype(np.float32))
+    data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
 
     input3_shape = (0, 0, 1, 10)
     output3_shape = (10, 0, 0, 10, 10)
-    data3 = tensor(np.random.random(input3_shape).astype(np.float32))
+    data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
 
     def comp(out, target_shp):
         assert out._tuple_shape == target_shp
@@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode):
 
 def test_device():
-    x = tensor([1, 2, 3], dtype="float32")
+    x = Tensor([1, 2, 3], dtype="float32")
 
     y1 = F.eye(x.shape, dtype="float32")
     y2 = F.eye(x.shape, dtype="float32", device=None)
@@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode):
 )
 @pytest.mark.parametrize("is_symbolic", [None, True, False])
 def test_copy_empty(shape, device_src, device_dst, is_symbolic):
-    inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
+    inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src)
 
     def func(inp):
         return F.copy(inp, device_dst)
@@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode):
     else:
        network = None
 
+    x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32)
+    y = F.roll(x, 1, -1)
+    np.testing.assert_equal(
+        y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32)
+    )
+
     inp = np.random.randn(*shape).astype("float32")
 
     def func(inp):
@@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode):
 )
 @pytest.mark.parametrize("is_symbolic", [None, True, False])
 def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
-    inp = tensor(np.random.randn(*shape).astype("float32"))
+    inp = Tensor(np.random.randn(*shape).astype("float32"))
 
     def func(inp):
         return F.roll(inp, shifts, axis)
@@ -1,8 +1,10 @@
+#include "../dnn_op_helper.h"
 #include "megbrain/imperative/ops/autogen.h"
 
 #include "../op_trait.h"
 
 #include "megbrain/opr/indexing.h"
+#include "megdnn/oprs/general.h"
 
 namespace mgb {
 namespace imperative {
@@ -12,10 +14,8 @@ namespace indexing_one_hot {
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
-    auto& op = def.cast_final_safe<IndexingOneHot>();
-
+    auto&& op = def.cast_final_safe<IndexingOneHot>();
     mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
     auto comp_node = input_descs[0].comp_node;
     TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
@@ -28,10 +28,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
     mgb_assert(src.is_contiguous(), "src should be contiguous");
     mgb_assert(
-            op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis);
+            -static_cast<int>(src.ndim) <= op.axis &&
+                    op.axis < static_cast<int>(src.ndim),
+            "axis %d not exists in src", op.axis);
+    int real_axis = static_cast<int>(op.axis);
+    if (real_axis < 0) {
+        real_axis += static_cast<int>(src.ndim);
+    }
     TensorLayout dst = src;
-    dst.shape[op.axis] = 1;
+    dst.shape[real_axis] = 1;
     dst.init_contiguous_stride();
 
     if (!index.ndim) {
@@ -40,24 +45,128 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     mgb_assert(index.is_contiguous(), "index should be all contiguous");
     mgb_assert(
-            index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src");
+            index.eq_shape(src.remove_axis(real_axis)),
+            "index shape doesn't match src");
     return {{{dst, comp_node}}, true};
 }
 
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const IndexingOneHot&>(def);
+    auto&& op = def.cast_final_safe<IndexingOneHot>();
     mgb_assert(inputs.size() == 2);
+    int real_axis = static_cast<int>(op.axis);
+    if (real_axis < 0) {
+        real_axis += static_cast<int>(op.ndim);
+    }
     OperatorNodeConfig config{op.make_name()};
-    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
+    return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, SmallVector<TensorPtr> inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = def.cast_final_safe<IndexingOneHot>();
+    auto&& inp = inputs[0];
+    auto&& index = inputs[1];
+    TensorLayout layout = inp->layout();
+    TensorLayout index_layout = index->layout();
+    DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node());
+    auto&& indexing_one_hot_param = dnn_op.op->param();
+    int real_axis = static_cast<int>(op.axis);
+    if (real_axis < 0) {
+        real_axis += static_cast<int>(layout.ndim);
+    }
+    mgb_assert(
+            0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
+            "Dimension out of range (expected to be in range of [%d, %d], but got %d)",
+            0, static_cast<int>(layout.ndim) - 1, op.axis);
+    indexing_one_hot_param = real_axis;
+    TensorLayout tlayout;
+    dnn_op.op->deduce_layout(layout, index_layout, tlayout);
+    TensorPtr out = Tensor::make(tlayout, inp->comp_node());
+    megdnn::TensorND in = inp->dnn_tensor();
+    megdnn::TensorND ind = index->dnn_tensor();
+    TensorLayout m_layout(
+            {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
+            dtype::Byte());
+    auto dnn_workspace = dnn_op.create_workspace(m_layout);
+    dnn_op.op->exec(in, ind, out->dnn_tensor(), dnn_workspace);
+    return {out};
 }
 
 OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .apply_on_var_node(apply_on_var_node)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
         .fallback();
 
 } // namespace indexing_one_hot
 
+namespace indexing_set_one_hot {
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
+    mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");
+    auto comp_node = input_descs[0].comp_node;
+    TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
+    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
+    if (!src.ndim) {
+        return {{{{{}, src.dtype}, comp_node}}, false};
+    }
+    mgb_assert(src.is_contiguous(), "src should be contiguous");
+    return {{input_descs[0]}, true};
+}
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const IndexingSetOneHot&>(def);
+    mgb_assert(inputs.size() == 3);
+    int real_axis = static_cast<int>(op.axis);
+    if (real_axis < 0) {
+        real_axis += static_cast<int>(op.ndim);
+    }
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingSetOneHot::make(
+            inputs[0], inputs[1], inputs[2], real_axis, config);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, SmallVector<TensorPtr> inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
+    auto&& inp = inputs[0];
+    auto&& index = inputs[1];
+    auto&& sub = inputs[2];
+    TensorLayout layout = inp->layout();
+    TensorLayout index_layout = index->layout();
+    TensorLayout tlayout = sub->layout();
+    mgb_assert(layout.is_contiguous());
+    DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node());
+    auto&& indexing_one_hot_param = dnn_op.op->param();
+    int real_axis = static_cast<int>(op.axis);
+    if (real_axis < 0) {
+        real_axis += static_cast<int>(layout.ndim);
+    }
+    indexing_one_hot_param = real_axis;
+    TensorPtr out = Tensor::make(layout, inp->comp_node());
+    out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());
+    megdnn::TensorND in = inp->dnn_tensor();
+    megdnn::TensorND ind = index->dnn_tensor();
+    megdnn::TensorND su = sub->dnn_tensor();
+    TensorLayout m_layout(
+            {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
+            dtype::Byte());
+    auto dnn_workspace = dnn_op.create_workspace(m_layout);
+    dnn_op.op->exec(out->dnn_tensor(), ind, su, dnn_workspace);
+    return {out};
+}
+
+OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_var_node(apply_on_var_node)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .fallback();
+
+} // namespace indexing_set_one_hot
+
 } // anonymous namespace
 } // namespace imperative
 } // namespace mgb
@@ -373,21 +373,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba
 } // namespace
 
 namespace {
-namespace indexing_set_one_hot {
-auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const IndexingSetOneHot&>(def);
-    mgb_assert(inputs.size() == 3);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::IndexingSetOneHot::make(
-            inputs[0], inputs[1], inputs[2], op.param(), config);
-}
-OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
-        .apply_on_var_node(apply_on_var_node)
-        .fallback();
-} // namespace indexing_set_one_hot
-} // namespace
-
-namespace {
 namespace typecvt {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = static_cast<const TypeCvt&>(def);
@@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;
 
 def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
 
-def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
+def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
+  let extraArguments = (ins
+    MgbI32Attr:$ndim
+  );
+}
 
-def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
+def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]> {
+  let extraArguments = (ins
+    MgbI32Attr:$ndim
+  );
+}
 
 def Copy: MgbHashableOp<"Copy"> {
   let extraArguments = (ins
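On the Python side the new ndim attribute is supplied when the op is constructed; a small end-to-end sketch reusing the imports already present in the updated tests (assumes the patch is applied):

    import numpy as np
    import megengine as mge
    from megengine.core._imperative_rt.core2 import apply
    from megengine.core.ops import builtin

    src = mge.Tensor(np.array([[1.0, 2.0]], dtype=np.float32))
    index = mge.Tensor(np.array([0], dtype=np.int32))
    op = builtin.IndexingOneHot(axis=-1, ndim=src.ndim)
    (result,) = apply(op, src, index)
    print(result.numpy())  # [[1.]]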