GitOrigin-RevId: bbe3ae3fa3
release-1.2
@@ -569,3 +569,9 @@ class AttrOutputNode(OpNode): | |||||
def reset(self): | def reset(self): | ||||
self._rendezvous.reset() | self._rendezvous.reset() | ||||
class VirtualDepNode(OpNode): | |||||
def __init__(self, vars, device=""): | |||||
out = _imperative_rt.virtual_dep(_unwrap(vars), device) | |||||
super().__init__(out) |
@@ -25,7 +25,6 @@ from ..core._imperative_rt.ops import ( | |||||
RemoteRecv, | RemoteRecv, | ||||
RemoteSend, | RemoteSend, | ||||
UniformRNG, | UniformRNG, | ||||
VirtualDep, | |||||
) | ) | ||||
from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
@@ -548,9 +547,10 @@ class trace: | |||||
need_reset_nodes.append(opnode) | need_reset_nodes.append(opnode) | ||||
info.varnode, *in_out_links = opnode.outputs | info.varnode, *in_out_links = opnode.outputs | ||||
if require_links and i == 0 and len(io_links) > 0: | if require_links and i == 0 and len(io_links) > 0: | ||||
info.varnode = apply( | |||||
VirtualDep(str(io_links[0].device)), info.varnode, *io_links | |||||
)[0] | |||||
opnode = G.VirtualDepNode( | |||||
[info.varnode, *io_links], str(io_links[0].device) | |||||
) | |||||
info.varnode = opnode.outputs[0] | |||||
io_links = (info.varnode,) | io_links = (info.varnode,) | ||||
ivars.append(info.varnode) | ivars.append(info.varnode) | ||||
@@ -1112,11 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||||
if require_links and active_trace._lazy_eval_links: | if require_links and active_trace._lazy_eval_links: | ||||
assert len(ivars) > 0, "op should has at least one input" | assert len(ivars) > 0, "op should has at least one input" | ||||
ivars[0] = apply( | |||||
VirtualDep(str(active_trace._lazy_eval_links[0].device)), | |||||
ivars[0], | |||||
*active_trace._lazy_eval_links, | |||||
)[0] | |||||
opnode = G.VirtualDepNode( | |||||
[ivars[0], *active_trace._lazy_eval_links], | |||||
str(active_trace._lazy_eval_links[0].device), | |||||
) | |||||
ivars[0] = opnode.outputs[0] | |||||
active_trace._lazy_eval_links = (ivars[0],) | active_trace._lazy_eval_links = (ivars[0],) | ||||
ovars = apply(op, *ivars) | ovars = apply(op, *ivars) | ||||
@@ -15,6 +15,7 @@ | |||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/imperative/opr_utility.h" | #include "megbrain/imperative/opr_utility.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/utility.h" | |||||
#include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
#include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
#include "./helper.h" | #include "./helper.h" | ||||
@@ -562,4 +563,16 @@ void init_graph_rt(py::module m) { | |||||
}; | }; | ||||
return output_callback(std::move(f), std::move(inputs), p, true); | return output_callback(std::move(f), std::move(inputs), p, true); | ||||
}); | }); | ||||
m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) { | |||||
auto&& graph = inputs[0]->owner_graph(); | |||||
VarNodeArray inps(inputs.begin(), inputs.end()); | |||||
cg::OperatorNodeConfig config; | |||||
if (device.length() > 0) { | |||||
config.comp_node(CompNode::load(device)); | |||||
} | |||||
cg::OperatorNodeBase* opr = graph->insert_opr( | |||||
std::make_unique<mgb::opr::VirtualDep>(inps, config)); | |||||
return opr; | |||||
}); | |||||
} | } |
@@ -10,12 +10,10 @@ | |||||
*/ | */ | ||||
#include "./ops.h" | #include "./ops.h" | ||||
#include <string> | |||||
#include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
#include "megbrain/imperative/ops/backward_graph.h" | #include "megbrain/imperative/ops/backward_graph.h" | ||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "megbrain/imperative/ops/utility.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
namespace py = pybind11; | namespace py = pybind11; | ||||
@@ -45,9 +43,5 @@ void init_ops(py::module m) { | |||||
return self.graph().interpret<py::object>(f, c, inputs); | return self.graph().interpret<py::object>(f, c, inputs); | ||||
}); | }); | ||||
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") | |||||
.def(py::init<>()) | |||||
.def(py::init<std::string>()); | |||||
#include "opdef.py.inl" | #include "opdef.py.inl" | ||||
} | } |
@@ -1,44 +0,0 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/utility.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/utility.h" | |||||
#include <string> | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/imperative/ops/opr_attr.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative { | |||||
namespace { | |||||
cg::OperatorNodeBase* virtual_dep_apply_on_var_node( | |||||
const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& graph = inputs[0]->owner_graph(); | |||||
auto&& op = def.cast_final_safe<VirtualDep>(); | |||||
VarNodeArray inps(inputs.begin(), inputs.end()); | |||||
cg::OperatorNodeConfig config; | |||||
if (op.device.length() > 0) { | |||||
config.comp_node(CompNode::load(op.device)); | |||||
} | |||||
cg::OperatorNodeBase* opr = | |||||
graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>( | |||||
inps, config)); | |||||
return opr; | |||||
} | |||||
OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) | |||||
.apply_on_var_node(virtual_dep_apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); | |||||
} // namespace mgb::imperative |
@@ -1,40 +0,0 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/ops/utility.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include "megbrain/graph/operator_node.h" | |||||
#include "megbrain/imperative/op_def.h" | |||||
#include "megbrain/utils/hash.h" | |||||
namespace mgb::imperative { | |||||
class VirtualDep : public OpDefImplBase<VirtualDep> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
public: | |||||
VirtualDep() = default; | |||||
VirtualDep(std::string dev) : device(dev) {} | |||||
std::string device; | |||||
size_t hash() const override { | |||||
return reinterpret_cast<size_t>(dyn_typeinfo()); | |||||
} | |||||
bool is_same_st(const Hashable& rhs) const override { | |||||
return true; | |||||
} | |||||
}; | |||||
} // namespace mgb::imperative |