From 50d285fced63253118a7ac2b17d502ed5070c21e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 15 May 2020 17:13:06 +0800 Subject: [PATCH] fix(mgb/opr): fix IndexingModifyMultiAxisVecHelper GitOrigin-RevId: 06f5a9110580116561ff6625956a8f9c630ebba5 --- src/opr/impl/indexing.cpp | 44 +++++++++++++++++++++++++++++++++++--------- src/opr/test/indexing.cpp | 28 ++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index 5cb68ca7..cfbdbc6a 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -32,6 +32,27 @@ namespace { index = opr::TypeCvt::make(index, dtype::Int32()); } } + + enum IndexingModifyType { + SET, INCR + }; + + template + struct IndexingModifyTypeGetter {}; + +#define REG(op, type) \ + template<> \ + struct IndexingModifyTypeGetter { \ + static constexpr IndexingModifyType value = IndexingModifyType::type; \ + }; +REG(IndexingIncrMultiAxisVec, INCR) +REG(IncrMeshIndexing, INCR) +REG(BatchedIncrMeshIndexing, INCR) +REG(IndexingSetMultiAxisVec, SET) +REG(SetMeshIndexing, SET) +REG(BatchedSetMeshIndexing, SET) +#undef REG + } namespace mgb { @@ -371,15 +392,20 @@ void intl::IndexingModifyMultiAxisVecHelper::scn_do_execute() { auto index_desc = this->make_megdnn_index_desc( inp.first.layout().ndim, ShouldWarnOnScalarIndexer::val); if (index_desc.empty()) { - if (std::is_same::value) { - inp.first.copy_from_fixlayout(inp.second); - } else { - static constexpr bool is_incr = std::is_same< - Opr, megdnn::IndexingIncrMultiAxisVec>::value; - mgb_assert(is_incr); - megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr< - megdnn::AddUpdate>(comp_node()); - add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn()); + using IMT = IndexingModifyType; + static constexpr auto modify_type = + IndexingModifyTypeGetter::value; + switch (modify_type) { + case IMT::SET: { + inp.first.copy_from_fixlayout(inp.second); + break; + } case IMT::INCR: { + megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr< + megdnn::AddUpdate>(comp_node()); + add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn()); + break; + } default: + mgb_throw(MegBrainError, "bad modify type"); } } else { this->megdnn_opr(*this).exec( diff --git a/src/opr/test/indexing.cpp b/src/opr/test/indexing.cpp index 143b59ba..cbb22bae 100644 --- a/src/opr/test/indexing.cpp +++ b/src/opr/test/indexing.cpp @@ -1165,6 +1165,34 @@ TEST(TestOprIndexing, SetMeshIndexing) { checker.run({TensorShape{8, 20, 10, 7, 7}, {1}, {9}, {3, 9, 1, 7, 7}}, opt); } + { // only interval AxisIndexer given + using Checker = AutoOprChecker<2, 1>; + auto make_graph = + [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + SymbolVar x = inputs[0], val = inputs[1]; + return {opr::SetMeshIndexing::make( + x, val, + {AIdx::make_interval(0, x.make_scalar(1), + None, x.make_scalar(2))})}; + }; + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + dest[0].copy_from(*inp[0]); + auto value = *inp[1]; + auto value_iter = megdnn::tensor_iter(value.as_megdnn()).begin(); + size_t n = dest[0].layout().stride[0]; + float* raw_ptr = dest[0].ptr(); + for (size_t i = 0; i < value.shape().total_nr_elems(); ++i) { + ptrdiff_t offset = (i / n * 2 + 1) * n + i % n; + *(raw_ptr + offset) = *value_iter; + ++ value_iter; + } + }; + Checker checker{make_graph, fwd}; + checker.run({TensorShape{11}, {5}}); + checker.run({TensorShape{6, 7}, {3, 7}}); + checker.run({TensorShape{4, 7, 1}, {2, 7, 1}}); + checker.run({TensorShape{7, 1, 1, 2}, {3, 1, 1, 2}}); + } } #endif // MGB_ENABLE_EXCEPTION