GitOrigin-RevId: bb3a59380e
@@ -12,6 +12,7 @@ import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
 
 from .. import _imperative_rt
+from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
 from ..ops.builtin import OpDef
 from .core import OpBase, TensorBase, apply
@@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode):
     return _wrap(outputs)
 
 
+@apply.register()
+def _(op: BackwardGraph, *args: VarNode):
+    assert args
+    graph = args[0].graph
+    return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args)
+
+
 def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
     outputs = _imperative_rt.input_callback(
         callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
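
Note (reviewer sketch, not part of the patch): the two callables passed to
BackwardGraph.interpret fix the domain in which the recorded graph is replayed.
The first, f(op, args) -> outputs, runs once per recorded expression; the
second, c(tensor) -> value, lifts each stored constant into that domain. Here f
re-enters apply(), so every op in the backward graph is dispatched like any
other op, and c is graph.make_const, which turns the constant tensors into
VarNodes of the enclosing graph. Any pair with the same shape works; for
example, a hypothetical tracing variant that logs each op before delegating:

    # Hypothetical debugging interpretation of the same backward graph:
    # log each recorded op, then delegate to the regular apply dispatch.
    def traced_apply(op, args):
        print(type(op).__name__, "applied to", len(args), "inputs")
        return apply(op, *args)

    outputs = op.interpret(traced_apply, graph.make_const, args)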
@@ -40,6 +40,18 @@ void init_ops(py::module m) {
                 attr.param.insert(attr.param.end(), s.begin(), s.end());
             });
 
+    py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
+        .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
+                             const mgb::SmallVector<py::object>& inputs) {
+                auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
+                    return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs));
+                };
+                auto c = [pyc](const TensorPtr& tensor) {
+                    return pyc(tensor->dev_tensor());
+                };
+                return self.graph().interpret<py::object>(f, c, inputs);
+            });
+
     py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape")
         .def(py::init());
@@ -98,7 +110,6 @@ void init_ops(py::module m) {
         .def(py::init<>())
         .def_readwrite("offsets", &ParamPackConcat::offsets);
 
-    py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph");
     py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake")
         .def(py::init<>());
 
@@ -18,34 +18,10 @@ namespace imperative {
 SmallVector<TensorPtr>
 BackwardGraph::InternalGraph::apply(
         const SmallVector<TensorPtr>& inputs) const {
-    ThinHashMap<size_t, TensorPtr> node2tensor;
-    auto&& input_nodes = this->inputs;
-    mgb_assert(inputs.size() == input_nodes.size());
-    for (size_t i = 0; i < inputs.size(); ++ i) {
-        node2tensor[input_nodes[i]] = inputs[i];
-    }
-    for (auto &&i : constants) {
-        node2tensor[i.first] = i.second;
-    }
-    for (size_t i = 0; i < exprs.size(); ++ i) {
-        auto&& expr = exprs[i];
-        SmallVector<TensorPtr> inputs;
-        for (auto &&in : std::get<1>(expr)) {
-            inputs.push_back(node2tensor.at(in));
-        }
-        auto outputs = OpDef::apply_on_physical_tensor(
-                *std::get<0>(expr), inputs);
-        auto output_nodes = std::get<2>(expr);
-        mgb_assert(outputs.size() == output_nodes.size());
-        for (size_t i = 0; i < outputs.size(); ++ i) {
-            node2tensor[output_nodes[i]] = outputs[i];
-        }
-    }
-    SmallVector<TensorPtr> ret;
-    for (auto &&i : outputs) {
-        ret.push_back(node2tensor.at(i));
-    }
-    return ret;
+    return interpret<TensorPtr>(
+            &OpDef::apply_on_physical_tensor,
+            [](const TensorPtr& x) {return x;},
+            inputs);
 }
 
 SmallVector<LogicalTensorDesc>
@@ -40,6 +40,37 @@ public:
         SmallVector<LogicalTensorDesc>
         infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const;
 
+        template <typename T, typename F, typename C>
+        SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
+            ThinHashMap<size_t, T> node2tensor;
+            auto&& input_nodes = this->inputs;
+            mgb_assert(inputs.size() == input_nodes.size());
+            for (size_t i = 0; i < inputs.size(); ++ i) {
+                node2tensor[input_nodes[i]] = inputs[i];
+            }
+            for (auto &&i : constants) {
+                node2tensor[i.first] = c(i.second);
+            }
+            for (size_t i = 0; i < exprs.size(); ++ i) {
+                auto&& expr = exprs[i];
+                SmallVector<T> inputs;
+                for (auto &&in : std::get<1>(expr)) {
+                    inputs.push_back(node2tensor.at(in));
+                }
+                auto&& outputs = f(*std::get<0>(expr), std::move(inputs));
+                auto&& output_nodes = std::get<2>(expr);
+                mgb_assert(outputs.size() == output_nodes.size());
+                for (size_t i = 0; i < outputs.size(); ++ i) {
+                    node2tensor[output_nodes[i]] = std::move(outputs[i]);
+                }
+            }
+            SmallVector<T> ret;
+            for (auto &&i : outputs) {
+                ret.push_back(node2tensor.at(i));
+            }
+            return ret;
+        }
+
     };
 
     const InternalGraph& graph() const {
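
Note (reviewer sketch, not part of the patch): the interpret template above is
a single forward pass over the recorded expressions, keeping a node-id -> value
map; constants enter through c, every expression runs through f, and only the
values named in outputs are returned. A self-contained pure-Python
transliteration, using a hypothetical dict layout in place of the
inputs/constants/exprs/outputs fields (each expr mirroring the
(op, input_ids, output_ids) tuples), may make the flow easier to follow:

    # Pure-Python sketch of InternalGraph::interpret; the graph layout below
    # is hypothetical, not MegEngine's actual representation.
    def interpret(graph, f, c, inputs):
        node2val = {}
        assert len(inputs) == len(graph["inputs"])
        for node, val in zip(graph["inputs"], inputs):
            node2val[node] = val
        for node, const in graph["constants"].items():
            node2val[node] = c(const)              # lift constants via c
        for op, in_ids, out_ids in graph["exprs"]:
            args = [node2val[i] for i in in_ids]   # gather operand values
            outs = f(op, args)                     # one f call per expression
            assert len(outs) == len(out_ids)
            for node, val in zip(out_ids, outs):
                node2val[node] = val
        return [node2val[i] for i in graph["outputs"]]

    # Toy run: node4 = (x * w) + 3.0, with ops as plain callables.
    g = {
        "inputs": [0, 1],
        "constants": {2: 3.0},
        "exprs": [(lambda a, b: [a * b], [0, 1], [3]),
                  (lambda a, b: [a + b], [3, 2], [4])],
        "outputs": [4],
    }
    print(interpret(g, lambda op, args: op(*args), float, [2.0, 5.0]))  # [13.0]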