|
|
@@ -32,6 +32,27 @@ namespace { |
|
|
|
index = opr::TypeCvt::make(index, dtype::Int32()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
enum IndexingModifyType { |
|
|
|
SET, INCR |
|
|
|
}; |
|
|
|
|
|
|
|
template<typename Opr> |
|
|
|
struct IndexingModifyTypeGetter {}; |
|
|
|
|
|
|
|
#define REG(op, type) \ |
|
|
|
template<> \ |
|
|
|
struct IndexingModifyTypeGetter<megdnn::op> { \ |
|
|
|
static constexpr IndexingModifyType value = IndexingModifyType::type; \ |
|
|
|
}; |
|
|
|
REG(IndexingIncrMultiAxisVec, INCR) |
|
|
|
REG(IncrMeshIndexing, INCR) |
|
|
|
REG(BatchedIncrMeshIndexing, INCR) |
|
|
|
REG(IndexingSetMultiAxisVec, SET) |
|
|
|
REG(SetMeshIndexing, SET) |
|
|
|
REG(BatchedSetMeshIndexing, SET) |
|
|
|
#undef REG |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
namespace mgb { |
|
|
@@ -371,15 +392,20 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() { |
|
|
|
auto index_desc = this->make_megdnn_index_desc( |
|
|
|
inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); |
|
|
|
if (index_desc.empty()) { |
|
|
|
if (std::is_same<Opr, megdnn::IndexingSetMultiAxisVec>::value) { |
|
|
|
inp.first.copy_from_fixlayout(inp.second); |
|
|
|
} else { |
|
|
|
static constexpr bool is_incr = std::is_same< |
|
|
|
Opr, megdnn::IndexingIncrMultiAxisVec>::value; |
|
|
|
mgb_assert(is_incr); |
|
|
|
megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr< |
|
|
|
megdnn::AddUpdate>(comp_node()); |
|
|
|
add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn()); |
|
|
|
using IMT = IndexingModifyType; |
|
|
|
static constexpr auto modify_type = |
|
|
|
IndexingModifyTypeGetter<Opr>::value; |
|
|
|
switch (modify_type) { |
|
|
|
case IMT::SET: { |
|
|
|
inp.first.copy_from_fixlayout(inp.second); |
|
|
|
break; |
|
|
|
} case IMT::INCR: { |
|
|
|
megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr< |
|
|
|
megdnn::AddUpdate>(comp_node()); |
|
|
|
add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn()); |
|
|
|
break; |
|
|
|
} default: |
|
|
|
mgb_throw(MegBrainError, "bad modify type"); |
|
|
|
} |
|
|
|
} else { |
|
|
|
this->megdnn_opr(*this).exec( |
|
|
|