@@ -22,6 +22,7 @@ | |||||
#include "megbrain/imperative/ops/elemwise.h" | #include "megbrain/imperative/ops/elemwise.h" | ||||
#include "megbrain/imperative/ops/batch_norm.h" | #include "megbrain/imperative/ops/batch_norm.h" | ||||
#include "megbrain/imperative/ops/broadcast.h" | #include "megbrain/imperative/ops/broadcast.h" | ||||
#include "megbrain/imperative/ops/utility.h" | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
@@ -113,6 +114,9 @@ void init_ops(py::module m) { | |||||
.def(py::init<>()) | .def(py::init<>()) | ||||
.def_readwrite("offsets", &ParamPackConcat::offsets); | .def_readwrite("offsets", &ParamPackConcat::offsets); | ||||
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") | |||||
.def(py::init<>()); | |||||
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") | py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") | ||||
.def(py::init<>()); | .def(py::init<>()); | ||||
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/utility.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/utility.h" | |||||
#include "megbrain/imperative/ops/opr_attr.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative { | |||||
namespace { | |||||
cg::OperatorNodeBase* virtual_dep_apply_on_var_node( | |||||
const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& graph = inputs[0]->owner_graph(); | |||||
VarNodeArray inps(inputs.begin(), inputs.end()); | |||||
cg::OperatorNodeConfig config; | |||||
cg::OperatorNodeBase* opr = | |||||
graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>( | |||||
inps, config)); | |||||
return opr; | |||||
} | |||||
OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) | |||||
.apply_on_var_node(virtual_dep_apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,35 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/ops/utility.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/imperative/op_def.h" | |||||
#include "megbrain/utils/hash.h" | |||||
namespace mgb::imperative { | |||||
class VirtualDep : public OpDefImplBase<VirtualDep> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
public: | |||||
VirtualDep() = default; | |||||
size_t hash() const override { | |||||
return reinterpret_cast<size_t>(dyn_typeinfo()); | |||||
} | |||||
bool is_same_st(const Hashable& rhs) const override { | |||||
return true; | |||||
} | |||||
}; | |||||
} // namespace mgb::imperative |