GitOrigin-RevId: bbe3ae3fa3
release-1.2
@@ -569,3 +569,9 @@ class AttrOutputNode(OpNode): | |||
def reset(self): | |||
self._rendezvous.reset() | |||
class VirtualDepNode(OpNode): | |||
def __init__(self, vars, device=""): | |||
out = _imperative_rt.virtual_dep(_unwrap(vars), device) | |||
super().__init__(out) |
@@ -25,7 +25,6 @@ from ..core._imperative_rt.ops import ( | |||
RemoteRecv, | |||
RemoteSend, | |||
UniformRNG, | |||
VirtualDep, | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..core._wrap import device as as_device | |||
@@ -548,9 +547,10 @@ class trace: | |||
need_reset_nodes.append(opnode) | |||
info.varnode, *in_out_links = opnode.outputs | |||
if require_links and i == 0 and len(io_links) > 0: | |||
info.varnode = apply( | |||
VirtualDep(str(io_links[0].device)), info.varnode, *io_links | |||
)[0] | |||
opnode = G.VirtualDepNode( | |||
[info.varnode, *io_links], str(io_links[0].device) | |||
) | |||
info.varnode = opnode.outputs[0] | |||
io_links = (info.varnode,) | |||
ivars.append(info.varnode) | |||
@@ -1112,11 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
if require_links and active_trace._lazy_eval_links: | |||
assert len(ivars) > 0, "op should has at least one input" | |||
ivars[0] = apply( | |||
VirtualDep(str(active_trace._lazy_eval_links[0].device)), | |||
ivars[0], | |||
*active_trace._lazy_eval_links, | |||
)[0] | |||
opnode = G.VirtualDepNode( | |||
[ivars[0], *active_trace._lazy_eval_links], | |||
str(active_trace._lazy_eval_links[0].device), | |||
) | |||
ivars[0] = opnode.outputs[0] | |||
active_trace._lazy_eval_links = (ivars[0],) | |||
ovars = apply(op, *ivars) | |||
@@ -15,6 +15,7 @@ | |||
#include "megbrain/serialization/serializer.h" | |||
#include "megbrain/imperative/opr_utility.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/imperative.h" | |||
#include "./helper.h" | |||
@@ -562,4 +563,16 @@ void init_graph_rt(py::module m) { | |||
}; | |||
return output_callback(std::move(f), std::move(inputs), p, true); | |||
}); | |||
m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) { | |||
auto&& graph = inputs[0]->owner_graph(); | |||
VarNodeArray inps(inputs.begin(), inputs.end()); | |||
cg::OperatorNodeConfig config; | |||
if (device.length() > 0) { | |||
config.comp_node(CompNode::load(device)); | |||
} | |||
cg::OperatorNodeBase* opr = graph->insert_opr( | |||
std::make_unique<mgb::opr::VirtualDep>(inps, config)); | |||
return opr; | |||
}); | |||
} |
@@ -10,12 +10,10 @@ | |||
*/ | |||
#include "./ops.h" | |||
#include <string> | |||
#include "megbrain/imperative.h" | |||
#include "megbrain/imperative/ops/backward_graph.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace py = pybind11; | |||
@@ -45,9 +43,5 @@ void init_ops(py::module m) { | |||
return self.graph().interpret<py::object>(f, c, inputs); | |||
}); | |||
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") | |||
.def(py::init<>()) | |||
.def(py::init<std::string>()); | |||
#include "opdef.py.inl" | |||
} |
@@ -1,44 +0,0 @@ | |||
/** | |||
* \file imperative/src/impl/ops/utility.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include <string> | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace { | |||
cg::OperatorNodeBase* virtual_dep_apply_on_var_node( | |||
const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& graph = inputs[0]->owner_graph(); | |||
auto&& op = def.cast_final_safe<VirtualDep>(); | |||
VarNodeArray inps(inputs.begin(), inputs.end()); | |||
cg::OperatorNodeConfig config; | |||
if (op.device.length() > 0) { | |||
config.comp_node(CompNode::load(op.device)); | |||
} | |||
cg::OperatorNodeBase* opr = | |||
graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>( | |||
inps, config)); | |||
return opr; | |||
} | |||
OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) | |||
.apply_on_var_node(virtual_dep_apply_on_var_node) | |||
.fallback(); | |||
} // namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); | |||
} // namespace mgb::imperative |
@@ -1,40 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/utility.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <string> | |||
#include "megbrain/graph/operator_node.h" | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/utils/hash.h" | |||
namespace mgb::imperative { | |||
class VirtualDep : public OpDefImplBase<VirtualDep> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
VirtualDep() = default; | |||
VirtualDep(std::string dev) : device(dev) {} | |||
std::string device; | |||
size_t hash() const override { | |||
return reinterpret_cast<size_t>(dyn_typeinfo()); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return true; | |||
} | |||
}; | |||
} // namespace mgb::imperative |