diff --git a/imperative/src/impl/ops/indexing.cpp b/imperative/src/impl/ops/indexing.cpp
new file mode 100644
index 00000000..7d012cf1
--- /dev/null
+++ b/imperative/src/impl/ops/indexing.cpp
@@ -0,0 +1,76 @@
+/**
+ * \file imperative/src/impl/ops/indexing.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/autogen.h"
+
+#include "../op_trait.h"
+
+#include "megbrain/opr/indexing.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+namespace indexing_one_hot {
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& input_descs) {
+    auto& op = def.cast_final_safe<IndexingOneHot>();
+
+    mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
+
+    auto comp_node = input_descs[0].comp_node;
+    TensorLayout src = input_descs[0].layout,
+                 index = input_descs[1].layout;
+
+    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
+
+    if (!src.ndim) {
+        return {{{{{}, src.dtype}, comp_node}}, false};
+    }
+
+    mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
+    mgb_assert(src.is_contiguous(), "src should be contiguous");
+    mgb_assert(op.axis >= 0 && op.axis < src.ndim, "axis %d does not exist in src", op.axis);
+
+    TensorLayout dst = src;
+    dst.shape[op.axis] = 1;
+    dst.init_contiguous_stride();
+
+    if (!index.ndim) {
+        return {{{dst, comp_node}}, false};
+    }
+
+    mgb_assert(index.is_contiguous(), "index should be contiguous");
+    mgb_assert(index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src");
+
+    return {{{dst, comp_node}}, true};
+}
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const IndexingOneHot&>(def);
+    mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
+}
+
+OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+
+} // namespace indexing_one_hot
+} // anonymous namespace
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index af07254d..3d596d3e 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -303,20 +303,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal)
 } // namespace
 
 namespace {
-namespace indexing_one_hot {
-auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const IndexingOneHot&>(def);
-    mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
-}
-OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
-        .apply_on_var_node(apply_on_var_node)
-        .fallback();
-} // namespace indexing_one_hot
-} // namespace
-
-namespace {
 namespace indexing_set_one_hot {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = static_cast<const IndexingSetOneHot&>(def);