|
|
@@ -1189,6 +1189,43 @@ TEST(TestTensorManip, SetSubtensor) { |
|
|
|
run(mkshp({18, 5, 2, 3}), opt); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestTensorManip, SetSubtensorCheckByShapeInfer) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
HostTensorGenerator<dtype::Int32> gen_int; |
|
|
|
auto host_x = gen({12}), host_sub = gen({1}), host_idx = gen_int({1}); |
|
|
|
host_idx->ptr<int>()[0] = 13; |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
using Ad = opr::Subtensor::AxisIndexer; |
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
sub = opr::Host2DeviceCopy::make(*graph, host_sub); |
|
|
|
auto idx1 = Ad::make_index(0, |
|
|
|
opr::ImmutableTensor::make(*graph, *host_idx)), |
|
|
|
idx2 = Ad::make_index(0, opr::Host2DeviceCopy::make(*graph, host_idx)); |
|
|
|
|
|
|
|
MGB_MARK_USED_VAR(x); |
|
|
|
MGB_MARK_USED_VAR(sub); |
|
|
|
MGB_MARK_USED_VAR(idx1); |
|
|
|
MGB_MARK_USED_VAR(idx2); |
|
|
|
ASSERT_THROW(opr::SetSubtensor::make(x, sub, {idx1}), MegBrainError); |
|
|
|
ASSERT_THROW(opr::SetSubtensor::make(x, sub, {idx2}), MegBrainError); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestTensorManip, SetSubtensorShapeInfer) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
HostTensorGenerator<dtype::Int32> gen_int; |
|
|
|
auto host_x = gen({12}), host_sub = gen({1}), host_idx = gen_int({1}); |
|
|
|
host_idx->ptr<int>()[0] = 13; |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
auto&& mgr = graph->static_infer_manager(); |
|
|
|
using Ad = opr::Subtensor::AxisIndexer; |
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
sub = opr::Host2DeviceCopy::make(*graph, host_sub), |
|
|
|
index = opr::Host2DeviceCopy::make_no_value_infer(*graph, host_idx); |
|
|
|
auto rt_static_idx = Ad::make_index(0, index * 2); |
|
|
|
auto y = opr::SetSubtensor::make(x, sub, {rt_static_idx}); |
|
|
|
ASSERT_TRUE(mgr.infer_shape_fallible(y.node())); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestTensorManip, SetSubtensorDynIdx) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto host_x = gen({12}), host_sub = gen({1}), |
|
|
|