|
|
@@ -568,6 +568,27 @@ TEST(TestOprBasicArith, TypeCvtBool) { |
|
|
|
ASSERT_EQ(TensorShape({3}), host_y.shape()); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprBasicArith, TypeCvtFromBool) { |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
HostTensorGenerator<dtype::Bool> gen; |
|
|
|
auto host_x = gen({2}); |
|
|
|
auto px = host_x->ptr<bool>(); |
|
|
|
px[0] = true; |
|
|
|
px[1] = false; |
|
|
|
|
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
y = opr::TypeCvt::make(x, dtype::Int32{}); |
|
|
|
HostTensorND host_y; |
|
|
|
auto func = graph->compile({make_callback_copy(y, host_y)}); |
|
|
|
func->execute(); |
|
|
|
|
|
|
|
auto py = host_y.ptr<int>(); |
|
|
|
for (size_t i = 0;i < 2;i ++) { |
|
|
|
ASSERT_EQ(static_cast<int>(px[i]), py[i]); |
|
|
|
} |
|
|
|
ASSERT_EQ(TensorShape({2}), host_y.shape()); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprBasicArith, ElemwiseMemFwd) { |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
graph->options().graph_opt_level = 0; |
|
|
|