GitOrigin-RevId: a5dc3b997c
release-1.5
@@ -7,12 +7,14 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
import pytest | |||
import megengine.functional as F | |||
import megengine.functional.elemwise as elemwise | |||
from megengine import tensor | |||
from megengine.core.tensor import dtype | |||
from megengine.functional.elemwise import Elemwise, _elwise | |||
from megengine.jit import trace | |||
def test_abs(): | |||
@@ -180,3 +182,80 @@ def test_int32_input(): | |||
inp = (x,) * nargs | |||
y = op(*inp) | |||
y.numpy() | |||
@pytest.mark.parametrize("is_trace", [True, False]) | |||
def test_empty_tensor(is_trace): | |||
binary_func = [] | |||
unary_func = [] | |||
for op_name in elemwise.__all__: | |||
op = getattr(elemwise, op_name) | |||
nargs = op.__code__.co_argcount | |||
if op_name == "clip": | |||
unary_func.append(["clip", lambda x, f=op: f(x, lower=0, upper=1)]) | |||
elif op_name.endswith("_shift"): | |||
unary_func.append( | |||
[op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="int32"), 1)] | |||
) | |||
elif op_name.startswith("logical_"): # logical_xxx op only accept boolean type | |||
if nargs == 1: | |||
unary_func.append( | |||
[op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="bool"))] | |||
) | |||
else: | |||
assert nargs == 2 | |||
binary_func.append( | |||
[ | |||
op_name, | |||
lambda x, y, f=op: f( | |||
tensor(x.numpy(), dtype="bool"), | |||
tensor(y.numpy(), dtype="bool"), | |||
), | |||
] | |||
) | |||
elif nargs == 1: | |||
unary_func.append([op_name, op]) | |||
elif nargs == 2: | |||
binary_func.append([op_name, op]) | |||
else: | |||
print(nargs) | |||
raise NotImplementedError | |||
def run_test(func, args, ref_shape, is_trace, sym=False): | |||
args = [tensor(t, dtype="float32") for t in args] | |||
if is_trace: | |||
func = trace(symbolic=sym)(func) | |||
for _ in range(3): | |||
out = func(*args) | |||
assert out.numpy().shape == ref_shape | |||
else: | |||
out = func(*args) | |||
assert out.numpy().shape == ref_shape | |||
print(out.numpy().shape) | |||
inps = [ | |||
np.array([]).astype("float32"), | |||
np.random.randn(2, 0, 3).astype("float32"), | |||
123, | |||
] | |||
for op_name, op in unary_func: | |||
if is_trace: | |||
for sym in [True, False]: | |||
run_test(op, [inps[0],], inps[0].shape, True, sym) | |||
run_test(op, [inps[1],], inps[1].shape, True, sym) | |||
else: | |||
run_test(op, [inps[0],], inps[0].shape, False) | |||
run_test(op, [inps[1],], inps[1].shape, False) | |||
for op_name, op in binary_func: | |||
if is_trace: | |||
for sym in [True, False]: | |||
run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, True, sym) | |||
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, True, sym) | |||
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, True, sym) | |||
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, True, sym) | |||
else: | |||
run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, False) | |||
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) | |||
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) | |||
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) |
@@ -19,6 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.tensor import megbrain_graph as G | |||
from megengine.core.tensor.utils import astensor1d | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.jit import trace | |||
from megengine.utils.network import Network, set_symbolic_shape | |||
from megengine.utils.network_node import VarNode | |||
@@ -177,6 +178,48 @@ def test_reshape(is_varnode): | |||
np.testing.assert_equal(yy.numpy(), y) | |||
@pytest.mark.parametrize("is_trace", [True, False]) | |||
def test_reshape_on_empty_tensor(is_trace): | |||
input1_shape = (100, 0, 1) | |||
output1_shape = (100, 0, 10) | |||
data1 = tensor(np.random.random(input1_shape).astype(np.float32)) | |||
input2_shape = (10, 0) | |||
output2_shape = (0,) | |||
data2 = tensor(np.random.random(input2_shape).astype(np.float32)) | |||
input3_shape = (10, 0, 10) | |||
output3_shape = (0, 1, 2, 3) | |||
data3 = tensor(np.random.random(input3_shape).astype(np.float32)) | |||
def comp(out, target_shp): | |||
assert out._tuple_shape == target_shp | |||
def func(x, shp): | |||
return F.reshape(x, shp) | |||
cases = [ | |||
[data1, output1_shape], | |||
[data2, output2_shape], | |||
[data3, output3_shape], | |||
] | |||
def test(func, inp, comp, target_shp): | |||
out = func(inp, target_shp) | |||
comp(out, target_shp) | |||
if is_trace: | |||
for symbolic in [False, True]: | |||
for inp, target_shp in cases: | |||
func_traced = trace(symbolic=symbolic)(func) | |||
test(func_traced, inp, comp, target_shp) | |||
test(func_traced, inp, comp, target_shp) | |||
test(func_traced, inp, comp, target_shp) | |||
else: | |||
for inp, target_shp in cases: | |||
test(func, inp, comp, target_shp) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_reshape_shape_inference(is_varnode): | |||
if is_varnode: | |||
@@ -480,6 +523,48 @@ def test_broadcast(is_varnode): | |||
F.broadcast_to(x, (1, 3)) | |||
@pytest.mark.parametrize("is_trace", [True, False]) | |||
def test_broadcast_on_empty_tensor(is_trace): | |||
input1_shape = (100, 0, 1) | |||
output1_shape = (100, 0, 10) | |||
data1 = tensor(np.random.random(input1_shape).astype(np.float32)) | |||
input2_shape = (10, 0) | |||
output2_shape = (10, 10, 0) | |||
data2 = tensor(np.random.random(input2_shape).astype(np.float32)) | |||
input3_shape = (0, 0, 1, 10) | |||
output3_shape = (10, 0, 0, 10, 10) | |||
data3 = tensor(np.random.random(input3_shape).astype(np.float32)) | |||
def comp(out, target_shp): | |||
assert out._tuple_shape == target_shp | |||
def func(x, shp): | |||
return F.broadcast_to(x, shp) | |||
cases = [ | |||
[data1, output1_shape], | |||
[data2, output2_shape], | |||
[data3, output3_shape], | |||
] | |||
def test(func, inp, comp, target_shp): | |||
out = func(inp, target_shp) | |||
comp(out, target_shp) | |||
if is_trace: | |||
for symbolic in [False, True]: | |||
for inp, target_shp in cases: | |||
func_traced = trace(symbolic=symbolic)(func) | |||
test(func_traced, inp, comp, target_shp) | |||
test(func_traced, inp, comp, target_shp) | |||
test(func_traced, inp, comp, target_shp) | |||
else: | |||
for inp, target_shp in cases: | |||
test(func, inp, comp, target_shp) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_utils_astensor1d(is_varnode): | |||
if is_varnode: | |||
@@ -259,6 +259,10 @@ void Elemwise::perform( | |||
mgb_assert(t.comp_node() == out_cn); | |||
mgb_assert(t.dtype() == out_dt); | |||
} | |||
if (t.shape().is_empty()) { | |||
mgb_assert(dest.empty()); | |||
return; | |||
} | |||
inp_shapes[i] = t.shape(); | |||
} | |||
if (!opr) { | |||
@@ -1064,4 +1064,37 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { | |||
MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); | |||
} | |||
TEST(TestOprBasicArithElemwise, PerformEmptyIO) { | |||
auto cn = CompNode::load("xpu0"); | |||
HostTensorGenerator<> gen; | |||
auto host_x1 = gen({2, 0, 3, 4}), | |||
host_x2 = gen({1}); | |||
auto dev_x1 = std::make_shared<DeviceTensorND>(cn), | |||
dev_x2 = std::make_shared<DeviceTensorND>(cn); | |||
dev_x1->copy_from(*host_x1); | |||
dev_x2->copy_from(*host_x2); | |||
auto dev_y = std::make_shared<DeviceTensorND>(cn, dev_x1->dtype()); | |||
dev_y->resize(dev_x1->shape()); | |||
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn); | |||
// test unary mode | |||
for (auto mode: {Mode::NEGATE, Mode::EXP, Mode::LOG}) { | |||
SmallVector<DeviceTensorND> inputs = {*dev_x1}; | |||
ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr)); | |||
ASSERT_TRUE(dev_y->empty()); | |||
ASSERT_TRUE(dev_y->shape().is_empty()); | |||
MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape()); | |||
} | |||
// test binary mode | |||
for (auto mode: {Mode::ADD, Mode::MUL, Mode::LT}) { | |||
SmallVector<DeviceTensorND> inputs = {*dev_x1, *dev_x2}; | |||
ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr)); | |||
ASSERT_TRUE(dev_y->empty()); | |||
ASSERT_TRUE(dev_y->shape().is_empty()); | |||
MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape()); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |