GitOrigin-RevId: f15b1d45a1
release-1.6
@@ -291,8 +291,10 @@ template<class Opr> | |||||
cg::OperatorNodeBase::NodeProp* | cg::OperatorNodeBase::NodeProp* | ||||
IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const { | IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const { | ||||
auto prop = Super::do_make_node_prop(); | auto prop = Super::do_make_node_prop(); | ||||
using DT = NodeProp::DepType; | |||||
// TODO: should also allow input shape is empty if any | // TODO: should also allow input shape is empty if any | ||||
// indexer's shape is empty | // indexer's shape is empty | ||||
prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY); | |||||
for (auto i: m_input2idxonly_axis_indexer) { | for (auto i: m_input2idxonly_axis_indexer) { | ||||
if (i) { | if (i) { | ||||
prop->add_dep_type_existing_var( | prop->add_dep_type_existing_var( | ||||
@@ -415,7 +417,7 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() { | |||||
auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute(); | auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute(); | ||||
auto index_desc = this->make_megdnn_index_desc( | auto index_desc = this->make_megdnn_index_desc( | ||||
inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); | inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); | ||||
if (index_desc.second){ | |||||
if (inp.first.shape().is_empty() || index_desc.second){ | |||||
mgb_assert(inp.second.shape().is_empty()); | mgb_assert(inp.second.shape().is_empty()); | ||||
return; | return; | ||||
} | } | ||||
@@ -476,10 +478,19 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET( | |||||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | ||||
); | ); | ||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | ||||
IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false); | |||||
IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false, | |||||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
); | |||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | ||||
IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); | IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); | ||||
IndexingSetMultiAxisVec::NodeProp* IndexingSetMultiAxisVec::do_make_node_prop() const { | |||||
auto prop = Super::do_make_node_prop(); | |||||
prop->add_dep_type_existing_var(input(0), | |||||
NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
return prop; | |||||
} | |||||
#if MGB_ENABLE_GRAD | #if MGB_ENABLE_GRAD | ||||
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | ||||
if (wrt_idx) | if (wrt_idx) | ||||
@@ -132,11 +132,11 @@ namespace intl { | |||||
void init_output_static_infer_desc() override final; | void init_output_static_infer_desc() override final; | ||||
void scn_do_execute() override final; | void scn_do_execute() override final; | ||||
NodeProp* do_make_node_prop() const override; | |||||
void add_input_layout_constraint() override final; | void add_input_layout_constraint() override final; | ||||
protected: | |||||
using Super::Super; | |||||
protected: | |||||
using Super::Super; | |||||
NodeProp* do_make_node_prop() const override; | |||||
}; | }; | ||||
} // namespace intl | } // namespace intl | ||||
@@ -158,6 +158,7 @@ public: | |||||
MGB_DEFINE_OPR_CLASS(IndexingSetMultiAxisVec, | MGB_DEFINE_OPR_CLASS(IndexingSetMultiAxisVec, | ||||
intl::IndexingModifyMultiAxisVecHelper<megdnn::IndexingSetMultiAxisVec> | intl::IndexingModifyMultiAxisVecHelper<megdnn::IndexingSetMultiAxisVec> | ||||
) // { | ) // { | ||||
NodeProp* do_make_node_prop() const override; | |||||
public: | public: | ||||
MGB_DECL_FANCY_INDEXING_OPR_MODIFY(IndexingSetMultiAxisVec); | MGB_DECL_FANCY_INDEXING_OPR_MODIFY(IndexingSetMultiAxisVec); | ||||
@@ -241,13 +241,15 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(_opr) | |||||
const OperatorNodeConfig &config = {}, \ | const OperatorNodeConfig &config = {}, \ | ||||
const InputTensorReplacer &input_tensor_replacer = {}) | const InputTensorReplacer &input_tensor_replacer = {}) | ||||
#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index) \ | |||||
#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index, \ | |||||
ctor_body...) \ | |||||
_opr::_opr(VarNode *inp, VarNode *value, const IndexDesc &desc, \ | _opr::_opr(VarNode *inp, VarNode *value, const IndexDesc &desc, \ | ||||
const OperatorNodeConfig &config, \ | const OperatorNodeConfig &config, \ | ||||
const InputTensorReplacer &input_tensor_replacer): \ | const InputTensorReplacer &input_tensor_replacer): \ | ||||
Super({inp->owner_graph(), config, _name, {inp, value}}, \ | Super({inp->owner_graph(), config, _name, {inp, value}}, \ | ||||
inp, value, desc, _require_scalar_index, input_tensor_replacer) \ | inp, value, desc, _require_scalar_index, input_tensor_replacer) \ | ||||
{ \ | { \ | ||||
ctor_body; \ | |||||
} \ | } \ | ||||
SymbolVar _opr::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc, \ | SymbolVar _opr::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc, \ | ||||
const OperatorNodeConfig &config, \ | const OperatorNodeConfig &config, \ | ||||