GitOrigin-RevId: a5dc3b997c
release-1.5
@@ -7,12 +7,14 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest
 import megengine.functional as F
 import megengine.functional.elemwise as elemwise
 from megengine import tensor
 from megengine.core.tensor import dtype
 from megengine.functional.elemwise import Elemwise, _elwise
+from megengine.jit import trace


 def test_abs():
@@ -180,3 +182,80 @@ def test_int32_input():
             inp = (x,) * nargs
         y = op(*inp)
         y.numpy()
+
+
+@pytest.mark.parametrize("is_trace", [True, False])
+def test_empty_tensor(is_trace):
+    binary_func = []
+    unary_func = []
+    for op_name in elemwise.__all__:
+        op = getattr(elemwise, op_name)
+        nargs = op.__code__.co_argcount
+        if op_name == "clip":
+            unary_func.append(["clip", lambda x, f=op: f(x, lower=0, upper=1)])
+        elif op_name.endswith("_shift"):
+            unary_func.append(
+                [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="int32"), 1)]
+            )
+        elif op_name.startswith("logical_"):  # logical_* ops only accept boolean inputs
+            if nargs == 1:
+                unary_func.append(
+                    [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="bool"))]
+                )
+            else:
+                assert nargs == 2
+                binary_func.append(
+                    [
+                        op_name,
+                        lambda x, y, f=op: f(
+                            tensor(x.numpy(), dtype="bool"),
+                            tensor(y.numpy(), dtype="bool"),
+                        ),
+                    ]
+                )
+        elif nargs == 1:
+            unary_func.append([op_name, op])
+        elif nargs == 2:
+            binary_func.append([op_name, op])
+        else:
+            raise NotImplementedError(
+                "only unary and binary elemwise ops are handled, got {} args".format(nargs)
+            )
+
+    def run_test(func, args, ref_shape, is_trace, sym=False):
+        args = [tensor(t, dtype="float32") for t in args]
+        if is_trace:
+            func = trace(symbolic=sym)(func)
+            for _ in range(3):
+                out = func(*args)
+                assert out.numpy().shape == ref_shape
+        else:
+            out = func(*args)
+            assert out.numpy().shape == ref_shape
+
+    inps = [
+        np.array([]).astype("float32"),
+        np.random.randn(2, 0, 3).astype("float32"),
+        123,
+    ]
+
+    for op_name, op in unary_func:
+        if is_trace:
+            for sym in [True, False]:
+                run_test(op, [inps[0]], inps[0].shape, True, sym)
+                run_test(op, [inps[1]], inps[1].shape, True, sym)
+        else:
+            run_test(op, [inps[0]], inps[0].shape, False)
+            run_test(op, [inps[1]], inps[1].shape, False)
+
+    for op_name, op in binary_func:
+        if is_trace:
+            for sym in [True, False]:
+                run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, True, sym)
+                run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, True, sym)
+                run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, True, sym)
+                run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, True, sym)
+        else:
+            run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, False)
+            run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False)
+            run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False)
+            run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False)
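
Taken together, the test above pins down the user-visible contract: every elemwise op must accept empty tensors, eagerly or under trace. A minimal sketch of that contract, assuming a MegEngine build that includes this patch:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.random.randn(2, 0, 3).astype("float32"))

y = F.exp(x)       # unary op: output shape mirrors the empty input
assert y.numpy().shape == (2, 0, 3)

z = F.add(x, 123)  # binary op with a scalar: broadcasting keeps the 0-size axis
assert z.numpy().shape == (2, 0, 3)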
@@ -19,6 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.utils import astensor1d
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.jit import trace
 from megengine.utils.network import Network, set_symbolic_shape
 from megengine.utils.network_node import VarNode
@@ -177,6 +178,48 @@ def test_reshape(is_varnode):
         np.testing.assert_equal(yy.numpy(), y)
@pytest.mark.parametrize("is_trace", [True, False]) | |||||
def test_reshape_on_empty_tensor(is_trace): | |||||
input1_shape = (100, 0, 1) | |||||
output1_shape = (100, 0, 10) | |||||
data1 = tensor(np.random.random(input1_shape).astype(np.float32)) | |||||
input2_shape = (10, 0) | |||||
output2_shape = (0,) | |||||
data2 = tensor(np.random.random(input2_shape).astype(np.float32)) | |||||
input3_shape = (10, 0, 10) | |||||
output3_shape = (0, 1, 2, 3) | |||||
data3 = tensor(np.random.random(input3_shape).astype(np.float32)) | |||||
def comp(out, target_shp): | |||||
assert out._tuple_shape == target_shp | |||||
def func(x, shp): | |||||
return F.reshape(x, shp) | |||||
cases = [ | |||||
[data1, output1_shape], | |||||
[data2, output2_shape], | |||||
[data3, output3_shape], | |||||
] | |||||
def test(func, inp, comp, target_shp): | |||||
out = func(inp, target_shp) | |||||
comp(out, target_shp) | |||||
if is_trace: | |||||
for symbolic in [False, True]: | |||||
for inp, target_shp in cases: | |||||
func_traced = trace(symbolic=symbolic)(func) | |||||
test(func_traced, inp, comp, target_shp) | |||||
test(func_traced, inp, comp, target_shp) | |||||
test(func_traced, inp, comp, target_shp) | |||||
else: | |||||
for inp, target_shp in cases: | |||||
test(func, inp, comp, target_shp) | |||||
@pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
def test_reshape_shape_inference(is_varnode): | def test_reshape_shape_inference(is_varnode): | ||||
if is_varnode: | if is_varnode: | ||||
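
For reference, the reshape cases above follow numpy semantics: a tensor with zero elements can be reshaped to any target shape whose element count is also zero. A minimal sketch, under the same build assumption as above:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.random.random((10, 0, 10)).astype(np.float32))
y = F.reshape(x, (0, 1, 2, 3))  # 0 elements -> 0 elements, any such shape is valid
assert y.numpy().shape == (0, 1, 2, 3)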
@@ -480,6 +523,48 @@ def test_broadcast(is_varnode):
         F.broadcast_to(x, (1, 3))
@pytest.mark.parametrize("is_trace", [True, False]) | |||||
def test_broadcast_on_empty_tensor(is_trace): | |||||
input1_shape = (100, 0, 1) | |||||
output1_shape = (100, 0, 10) | |||||
data1 = tensor(np.random.random(input1_shape).astype(np.float32)) | |||||
input2_shape = (10, 0) | |||||
output2_shape = (10, 10, 0) | |||||
data2 = tensor(np.random.random(input2_shape).astype(np.float32)) | |||||
input3_shape = (0, 0, 1, 10) | |||||
output3_shape = (10, 0, 0, 10, 10) | |||||
data3 = tensor(np.random.random(input3_shape).astype(np.float32)) | |||||
def comp(out, target_shp): | |||||
assert out._tuple_shape == target_shp | |||||
def func(x, shp): | |||||
return F.broadcast_to(x, shp) | |||||
cases = [ | |||||
[data1, output1_shape], | |||||
[data2, output2_shape], | |||||
[data3, output3_shape], | |||||
] | |||||
def test(func, inp, comp, target_shp): | |||||
out = func(inp, target_shp) | |||||
comp(out, target_shp) | |||||
if is_trace: | |||||
for symbolic in [False, True]: | |||||
for inp, target_shp in cases: | |||||
func_traced = trace(symbolic=symbolic)(func) | |||||
test(func_traced, inp, comp, target_shp) | |||||
test(func_traced, inp, comp, target_shp) | |||||
test(func_traced, inp, comp, target_shp) | |||||
else: | |||||
for inp, target_shp in cases: | |||||
test(func, inp, comp, target_shp) | |||||
@pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
def test_utils_astensor1d(is_varnode): | def test_utils_astensor1d(is_varnode): | ||||
if is_varnode: | if is_varnode: | ||||
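
The broadcast cases rely on the standard broadcasting rule extended to zero-size axes: trailing axes must match (with size-1 axes expandable) and new leading axes may be prepended, so an empty input stays empty. A sketch, under the same assumption:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.random.random((10, 0)).astype(np.float32))
y = F.broadcast_to(x, (10, 10, 0))  # prepend an axis; the 0-size axis is preserved
assert y.numpy().shape == (10, 10, 0)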
@@ -259,6 +259,10 @@ void Elemwise::perform(
             mgb_assert(t.comp_node() == out_cn);
             mgb_assert(t.dtype() == out_dt);
         }
+        // an empty input forces an empty output: nothing to compute, so
+        // verify the dest tensor is empty and skip the kernel dispatch
+        if (t.shape().is_empty()) {
+            mgb_assert(dest.empty());
+            return;
+        }
         inp_shapes[i] = t.shape();
     }
     if (!opr) {
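
Because perform() now returns before dispatching the megdnn kernel, even modes whose math is undefined on empty input (LOG here, exercised by the test below) complete cleanly. From Python, assuming a build with this patch:

import numpy as np
import megengine.functional as F
from megengine import tensor

# no kernel runs on an empty input, so there is no NaN/Inf or device error
out = F.log(tensor(np.array([]).astype("float32")))
assert out.numpy().shape == (0,)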
@@ -1064,4 +1064,37 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
     MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
 }
+
+TEST(TestOprBasicArithElemwise, PerformEmptyIO) {
+    auto cn = CompNode::load("xpu0");
+    HostTensorGenerator<> gen;
+    auto host_x1 = gen({2, 0, 3, 4}), host_x2 = gen({1});
+    auto dev_x1 = std::make_shared<DeviceTensorND>(cn),
+         dev_x2 = std::make_shared<DeviceTensorND>(cn);
+    dev_x1->copy_from(*host_x1);
+    dev_x2->copy_from(*host_x2);
+    auto dev_y = std::make_shared<DeviceTensorND>(cn, dev_x1->dtype());
+    dev_y->resize(dev_x1->shape());
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn);
+
+    // unary modes: an empty input must yield an empty output of the same shape
+    for (auto mode : {Mode::NEGATE, Mode::EXP, Mode::LOG}) {
+        SmallVector<DeviceTensorND> inputs = {*dev_x1};
+        ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
+        ASSERT_TRUE(dev_y->empty());
+        ASSERT_TRUE(dev_y->shape().is_empty());
+        MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
+    }
+
+    // binary modes: an empty tensor broadcast with a scalar stays empty
+    for (auto mode : {Mode::ADD, Mode::MUL, Mode::LT}) {
+        SmallVector<DeviceTensorND> inputs = {*dev_x1, *dev_x2};
+        ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
+        ASSERT_TRUE(dev_y->empty());
+        ASSERT_TRUE(dev_y->shape().is_empty());
+        MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}