|
|
@@ -546,6 +546,28 @@ TEST(TestOprBasicArith, TypeCvt) { |
|
|
|
ASSERT_EQ(TensorShape({3, 0}), host_y.shape()); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprBasicArith, TypeCvtBool) { |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
HostTensorGenerator<dtype::Int32> gen; |
|
|
|
auto host_x = gen({3}); |
|
|
|
auto px = host_x->ptr<int>(); |
|
|
|
px[0] = -1; |
|
|
|
px[1] = 0; |
|
|
|
px[2] = 1; |
|
|
|
|
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
y = opr::TypeCvt::make(x, dtype::Bool{}); |
|
|
|
HostTensorND host_y; |
|
|
|
auto func = graph->compile({make_callback_copy(y, host_y)}); |
|
|
|
func->execute(); |
|
|
|
|
|
|
|
auto py = host_y.ptr<bool>(); |
|
|
|
for (size_t i = 0;i < 3;i ++) { |
|
|
|
ASSERT_EQ(static_cast<bool>(px[i]), py[i]); |
|
|
|
} |
|
|
|
ASSERT_EQ(TensorShape({3}), host_y.shape()); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprBasicArith, ElemwiseMemFwd) { |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
graph->options().graph_opt_level = 0; |
|
|
|