|
|
@@ -649,6 +649,17 @@ namespace { |
|
|
|
> TernaryTraitTypes; |
|
|
|
TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes); |
|
|
|
|
|
|
|
::testing::AssertionResult assert_shape_equal(const TensorShape& v0, |
|
|
|
const TensorShape& v1) { |
|
|
|
if (v0.eq_shape(v1)) |
|
|
|
return ::testing::AssertionSuccess() |
|
|
|
<< v0.to_string() << " == " << v1.to_string(); |
|
|
|
else |
|
|
|
return ::testing::AssertionFailure() |
|
|
|
<< v0.to_string() << " != " << v1.to_string(); |
|
|
|
} |
|
|
|
#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1)) |
|
|
|
|
|
|
|
} // anonymous namespace |
|
|
|
|
|
|
|
template<typename Trait, typename dtype> |
|
|
@@ -950,4 +961,58 @@ TEST(TestLayoutUtil, CollectiveCollapse) { |
|
|
|
check(cc_res5, std_res5); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
auto host_x = gen({3, 0, 1, 3}); |
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
y = opr::Elemwise::make( |
|
|
|
{x}, opr::Elemwise::Param(opr::Elemwise::Param::Mode::RELU)); |
|
|
|
HostTensorND host_y; |
|
|
|
auto func = graph->compile({make_callback_copy(y, host_y)}); |
|
|
|
|
|
|
|
ASSERT_NO_THROW(func->execute().wait()); |
|
|
|
ASSERT_TRUE(host_y.empty()); |
|
|
|
ASSERT_TRUE(host_y.shape().is_empty()); |
|
|
|
ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3})); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
auto host_x = gen({0, 8, 1, 7}), host_y = gen({0, 8, 1, 7}); |
|
|
|
|
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
y = opr::Host2DeviceCopy::make(*graph, host_y), |
|
|
|
z = x + y; |
|
|
|
HostTensorND host_z; |
|
|
|
auto func = graph->compile({make_callback_copy(z, host_z)}); |
|
|
|
|
|
|
|
// Invalid broadcast |
|
|
|
host_y->resize({0, 9, 1, 7}); |
|
|
|
ASSERT_ANY_THROW(func->execute().wait()); |
|
|
|
|
|
|
|
// Broadcast to 0 |
|
|
|
host_y->resize({1, 8, 0, 7}); |
|
|
|
ASSERT_NO_THROW(func->execute().wait()); |
|
|
|
ASSERT_TRUE(host_z.empty()); |
|
|
|
ASSERT_TRUE(host_z.shape().is_empty()); |
|
|
|
ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7})); |
|
|
|
|
|
|
|
// Broadcast to 0 (2) |
|
|
|
host_y->resize({2, 8, 1, 7}); |
|
|
|
ASSERT_NO_THROW(func->execute().wait()); |
|
|
|
ASSERT_TRUE(host_z.empty()); |
|
|
|
ASSERT_TRUE(host_z.shape().is_empty()); |
|
|
|
ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); |
|
|
|
|
|
|
|
// Scalar broadcast |
|
|
|
z = x + x.make_scalar(1.f); |
|
|
|
func = graph->compile({make_callback_copy(z, host_z)}); |
|
|
|
ASSERT_NO_THROW(func->execute().wait()); |
|
|
|
ASSERT_TRUE(host_z.empty()); |
|
|
|
ASSERT_TRUE(host_z.shape().is_empty()); |
|
|
|
ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |