Browse Source

fix(ops): check index layout for IndexingOneHot

GitOrigin-RevId: 80b42a8de9
release-1.6
Megvii Engine Team 3 years ago
parent
commit
73d25779ae
2 changed files with 76 additions and 14 deletions
  1. +76
    -0
      imperative/src/impl/ops/indexing.cpp
  2. +0
    -14
      imperative/src/impl/ops/specializations.cpp

+ 76
- 0
imperative/src/impl/ops/indexing.cpp View File

@@ -0,0 +1,76 @@
/**
* \file imperative/src/impl/ops/indexing.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

#include "megbrain/opr/indexing.h"

namespace mgb {
namespace imperative {

namespace {
namespace indexing_one_hot {

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<IndexingOneHot>();

mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");

auto comp_node = input_descs[0].comp_node;
TensorLayout src = input_descs[0].layout,
index = input_descs[1].layout;

mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");

if (!src.ndim) {
return {{{{{}, src.dtype}, comp_node}}, false};
}

mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
mgb_assert(src.is_contiguous(), "src should be contiguous");
mgb_assert(op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis);

TensorLayout dst = src;
dst.shape[op.axis] = 1;
dst.init_contiguous_stride();

if (!index.ndim) {
return {{{dst, comp_node}}, false};
}

mgb_assert(index.is_contiguous(), "index should be all contiguous");
mgb_assert(index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src");

return {{{dst, comp_node}}, true};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.fallback();

} // namespace indexing_one_hot
} // anonymous namespace
} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 0
- 14
imperative/src/impl/ops/specializations.cpp View File

@@ -303,20 +303,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal)
} // namespace

namespace {
namespace indexing_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace indexing_one_hot
} // namespace

namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);


Loading…
Cancel
Save