GitOrigin-RevId: 06f5a91105
tags/v0.5.0
@@ -32,6 +32,27 @@ namespace { | |||||
index = opr::TypeCvt::make(index, dtype::Int32()); | index = opr::TypeCvt::make(index, dtype::Int32()); | ||||
} | } | ||||
} | } | ||||
enum IndexingModifyType { | |||||
SET, INCR | |||||
}; | |||||
template<typename Opr> | |||||
struct IndexingModifyTypeGetter {}; | |||||
#define REG(op, type) \ | |||||
template<> \ | |||||
struct IndexingModifyTypeGetter<megdnn::op> { \ | |||||
static constexpr IndexingModifyType value = IndexingModifyType::type; \ | |||||
}; | |||||
REG(IndexingIncrMultiAxisVec, INCR) | |||||
REG(IncrMeshIndexing, INCR) | |||||
REG(BatchedIncrMeshIndexing, INCR) | |||||
REG(IndexingSetMultiAxisVec, SET) | |||||
REG(SetMeshIndexing, SET) | |||||
REG(BatchedSetMeshIndexing, SET) | |||||
#undef REG | |||||
} | } | ||||
namespace mgb { | namespace mgb { | ||||
@@ -371,15 +392,20 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() { | |||||
auto index_desc = this->make_megdnn_index_desc( | auto index_desc = this->make_megdnn_index_desc( | ||||
inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); | inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); | ||||
if (index_desc.empty()) { | if (index_desc.empty()) { | ||||
if (std::is_same<Opr, megdnn::IndexingSetMultiAxisVec>::value) { | |||||
inp.first.copy_from_fixlayout(inp.second); | |||||
} else { | |||||
static constexpr bool is_incr = std::is_same< | |||||
Opr, megdnn::IndexingIncrMultiAxisVec>::value; | |||||
mgb_assert(is_incr); | |||||
megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr< | |||||
megdnn::AddUpdate>(comp_node()); | |||||
add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn()); | |||||
using IMT = IndexingModifyType; | |||||
static constexpr auto modify_type = | |||||
IndexingModifyTypeGetter<Opr>::value; | |||||
switch (modify_type) { | |||||
case IMT::SET: { | |||||
inp.first.copy_from_fixlayout(inp.second); | |||||
break; | |||||
} case IMT::INCR: { | |||||
megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr< | |||||
megdnn::AddUpdate>(comp_node()); | |||||
add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn()); | |||||
break; | |||||
} default: | |||||
mgb_throw(MegBrainError, "bad modify type"); | |||||
} | } | ||||
} else { | } else { | ||||
this->megdnn_opr(*this).exec( | this->megdnn_opr(*this).exec( | ||||
@@ -1165,6 +1165,34 @@ TEST(TestOprIndexing, SetMeshIndexing) { | |||||
checker.run({TensorShape{8, 20, 10, 7, 7}, {1}, {9}, {3, 9, 1, 7, 7}}, | checker.run({TensorShape{8, 20, 10, 7, 7}, {1}, {9}, {3, 9, 1, 7, 7}}, | ||||
opt); | opt); | ||||
} | } | ||||
{ // only interval AxisIndexer given | |||||
using Checker = AutoOprChecker<2, 1>; | |||||
auto make_graph = | |||||
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
SymbolVar x = inputs[0], val = inputs[1]; | |||||
return {opr::SetMeshIndexing::make( | |||||
x, val, | |||||
{AIdx::make_interval(0, x.make_scalar(1), | |||||
None, x.make_scalar(2))})}; | |||||
}; | |||||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
dest[0].copy_from(*inp[0]); | |||||
auto value = *inp[1]; | |||||
auto value_iter = megdnn::tensor_iter<float>(value.as_megdnn()).begin(); | |||||
size_t n = dest[0].layout().stride[0]; | |||||
float* raw_ptr = dest[0].ptr<float>(); | |||||
for (size_t i = 0; i < value.shape().total_nr_elems(); ++i) { | |||||
ptrdiff_t offset = (i / n * 2 + 1) * n + i % n; | |||||
*(raw_ptr + offset) = *value_iter; | |||||
++ value_iter; | |||||
} | |||||
}; | |||||
Checker checker{make_graph, fwd}; | |||||
checker.run({TensorShape{11}, {5}}); | |||||
checker.run({TensorShape{6, 7}, {3, 7}}); | |||||
checker.run({TensorShape{4, 7, 1}, {2, 7, 1}}); | |||||
checker.run({TensorShape{7, 1, 1, 2}, {3, 1, 1, 2}}); | |||||
} | |||||
} | } | ||||
#endif // MGB_ENABLE_EXCEPTION | #endif // MGB_ENABLE_EXCEPTION | ||||