GitOrigin-RevId: bb3a59380e
tags/v1.0.0-rc1
@@ -12,6 +12,7 @@ import weakref | |||
from concurrent.futures import Future, ThreadPoolExecutor | |||
from .. import _imperative_rt | |||
from .._imperative_rt.ops import BackwardGraph | |||
from .._wrap import device as as_device | |||
from ..ops.builtin import OpDef | |||
from .core import OpBase, TensorBase, apply | |||
@@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode): | |||
return _wrap(outputs) | |||
@apply.register() | |||
def _(op: BackwardGraph, *args: VarNode): | |||
assert args | |||
graph = args[0].graph | |||
return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args) | |||
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | |||
outputs = _imperative_rt.input_callback( | |||
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | |||
@@ -40,6 +40,18 @@ void init_ops(py::module m) { | |||
attr.param.insert(attr.param.end(), s.begin(), s.end()); | |||
}); | |||
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||
const mgb::SmallVector<py::object>& inputs) { | |||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs)); | |||
}; | |||
auto c = [pyc](const TensorPtr& tensor) { | |||
return pyc(tensor->dev_tensor()); | |||
}; | |||
return self.graph().interpret<py::object>(f, c, inputs); | |||
}); | |||
py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape") | |||
.def(py::init()); | |||
@@ -98,7 +110,6 @@ void init_ops(py::module m) { | |||
.def(py::init<>()) | |||
.def_readwrite("offsets", &ParamPackConcat::offsets); | |||
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph"); | |||
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") | |||
.def(py::init<>()); | |||
@@ -18,34 +18,10 @@ namespace imperative { | |||
SmallVector<TensorPtr> | |||
BackwardGraph::InternalGraph::apply( | |||
const SmallVector<TensorPtr>& inputs) const { | |||
ThinHashMap<size_t, TensorPtr> node2tensor; | |||
auto&& input_nodes = this->inputs; | |||
mgb_assert(inputs.size() == input_nodes.size()); | |||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||
node2tensor[input_nodes[i]] = inputs[i]; | |||
} | |||
for (auto &&i : constants) { | |||
node2tensor[i.first] = i.second; | |||
} | |||
for (size_t i = 0; i < exprs.size(); ++ i) { | |||
auto&& expr = exprs[i]; | |||
SmallVector<TensorPtr> inputs; | |||
for (auto &&in : std::get<1>(expr)) { | |||
inputs.push_back(node2tensor.at(in)); | |||
} | |||
auto outputs = OpDef::apply_on_physical_tensor( | |||
*std::get<0>(expr), inputs); | |||
auto output_nodes = std::get<2>(expr); | |||
mgb_assert(outputs.size() == output_nodes.size()); | |||
for (size_t i = 0; i < outputs.size(); ++ i) { | |||
node2tensor[output_nodes[i]] = outputs[i]; | |||
} | |||
} | |||
SmallVector<TensorPtr> ret; | |||
for (auto &&i : outputs) { | |||
ret.push_back(node2tensor.at(i)); | |||
} | |||
return ret; | |||
return interpret<TensorPtr>( | |||
&OpDef::apply_on_physical_tensor, | |||
[](const TensorPtr& x) {return x;}, | |||
inputs); | |||
} | |||
SmallVector<LogicalTensorDesc> | |||
@@ -40,6 +40,37 @@ public: | |||
SmallVector<LogicalTensorDesc> | |||
infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const; | |||
template <typename T, typename F, typename C> | |||
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const { | |||
ThinHashMap<size_t, T> node2tensor; | |||
auto&& input_nodes = this->inputs; | |||
mgb_assert(inputs.size() == input_nodes.size()); | |||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||
node2tensor[input_nodes[i]] = inputs[i]; | |||
} | |||
for (auto &&i : constants) { | |||
node2tensor[i.first] = c(i.second); | |||
} | |||
for (size_t i = 0; i < exprs.size(); ++ i) { | |||
auto&& expr = exprs[i]; | |||
SmallVector<T> inputs; | |||
for (auto &&in : std::get<1>(expr)) { | |||
inputs.push_back(node2tensor.at(in)); | |||
} | |||
auto&& outputs = f(*std::get<0>(expr), std::move(inputs)); | |||
auto&& output_nodes = std::get<2>(expr); | |||
mgb_assert(outputs.size() == output_nodes.size()); | |||
for (size_t i = 0; i < outputs.size(); ++ i) { | |||
node2tensor[output_nodes[i]] = std::move(outputs[i]); | |||
} | |||
} | |||
SmallVector<T> ret; | |||
for (auto &&i : outputs) { | |||
ret.push_back(node2tensor.at(i)); | |||
} | |||
return ret; | |||
} | |||
}; | |||
const InternalGraph& graph() const { | |||