GitOrigin-RevId: fd1265c661
release-1.4
@@ -170,9 +170,9 @@ class trace: | |||||
self._graph = None | self._graph = None | ||||
self._need_reset_nodes = None | self._need_reset_nodes = None | ||||
self._lazy_eval_graph = None | self._lazy_eval_graph = None | ||||
self._lazy_eval_tensors = {} | |||||
self._lazy_eval_tensors = set() | |||||
self._lazy_eval_links = None | self._lazy_eval_links = None | ||||
self._active_tensors = {} | |||||
self._active_tensors = set() | |||||
self._tensor_remaps = None | self._tensor_remaps = None | ||||
self._inputs_to_restore = None | self._inputs_to_restore = None | ||||
self._arg_bindings = None | self._arg_bindings = None | ||||
@@ -258,7 +258,7 @@ class trace: | |||||
y._compiled_info = CompiledTensorProxy(h) | y._compiled_info = CompiledTensorProxy(h) | ||||
y._mixin_handle = h | y._mixin_handle = h | ||||
outputs += [y] | outputs += [y] | ||||
self._active_tensors[h] = TensorWeakRef(y) | |||||
self._active_tensors.add(TensorWeakRef(y)) | |||||
self._output_handles.update(ohandles) | self._output_handles.update(ohandles) | ||||
return outputs | return outputs | ||||
@@ -318,9 +318,9 @@ class trace: | |||||
x._mixin_handle = h | x._mixin_handle = h | ||||
x._recording = True | x._recording = True | ||||
x._trace_mixin_info = info | x._trace_mixin_info = info | ||||
self._active_tensors[h] = TensorWeakRef(x) | |||||
self._active_tensors.add(TensorWeakRef(x)) | |||||
if self._symbolic: | if self._symbolic: | ||||
self._lazy_eval_tensors[h] = TensorWeakRef(x) | |||||
self._lazy_eval_tensors.add(TensorWeakRef(x)) | |||||
self._seq.append((op, tuple(ihandles), tuple(ohandles))) | self._seq.append((op, tuple(ihandles), tuple(ohandles))) | ||||
@@ -345,7 +345,7 @@ class trace: | |||||
x._recording = True | x._recording = True | ||||
x._trace_mixin_info = info | x._trace_mixin_info = info | ||||
if self._symbolic: | if self._symbolic: | ||||
self._lazy_eval_tensors[h] = TensorWeakRef(x) | |||||
self._lazy_eval_tensors.add(TensorWeakRef(x)) | |||||
self._seq.append(("Const", tuple(), tuple(ohandles))) | self._seq.append(("Const", tuple(), tuple(ohandles))) | ||||
def _set_active(self, active: bool): | def _set_active(self, active: bool): | ||||
@@ -365,17 +365,14 @@ class trace: | |||||
self._lazy_eval_links = () | self._lazy_eval_links = () | ||||
def _take_escaped_tensors(self): | def _take_escaped_tensors(self): | ||||
escaped_tensors = tuple( | |||||
filter(lambda x: x() is not None, self._active_tensors.values()) | |||||
) | |||||
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) | |||||
self._active_tensors.clear() | self._active_tensors.clear() | ||||
return escaped_tensors | return escaped_tensors | ||||
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | ||||
lazy_eval_tensors = list( | |||||
filter(lambda x: x() is not None, lazy_eval_tensors.values()) | |||||
) | |||||
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | |||||
lazy_eval_tensors = [x() for x in lazy_eval_tensors] | |||||
lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None] | |||||
readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors] | |||||
self._apply_graph_options(lazy_eval_graph) | self._apply_graph_options(lazy_eval_graph) | ||||
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level | lazy_eval_graph.options.graph_opt_level = self._graph_opt_level | ||||
lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) | lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) | ||||
@@ -383,8 +380,8 @@ class trace: | |||||
lazy_eval_graph() | lazy_eval_graph() | ||||
for r, x in zip(readers, lazy_eval_tensors): | for r, x in zip(readers, lazy_eval_tensors): | ||||
# get values from lazy_eval_graph and assign to lazy_eval tensor | # get values from lazy_eval_graph and assign to lazy_eval tensor | ||||
x()._handle = RawTensor(r.op.get_value())._handle | |||||
x()._reset_varnode() | |||||
x._handle = RawTensor(r.op.get_value())._handle | |||||
x._reset_varnode() | |||||
@contextlib.contextmanager | @contextlib.contextmanager | ||||
def _setup(self): | def _setup(self): | ||||
@@ -454,13 +451,14 @@ class trace: | |||||
raise TraceMismatchError("premature end") | raise TraceMismatchError("premature end") | ||||
if not self._symbolic or not self._untraced: | if not self._symbolic or not self._untraced: | ||||
# reset output tensors | # reset output tensors | ||||
for x in self._active_tensors.values(): | |||||
if x() is not None: | |||||
x()._dev_tensor() | |||||
x()._reset_varnode() | |||||
x()._mixin_handle = -1 | |||||
x()._recording = False | |||||
x()._trace_mixin_info = None | |||||
for x in self._active_tensors.copy(): | |||||
strong_x = x() | |||||
if strong_x is not None: | |||||
strong_x._dev_tensor() | |||||
strong_x._reset_varnode() | |||||
strong_x._mixin_handle = -1 | |||||
strong_x._recording = False | |||||
strong_x._trace_mixin_info = None | |||||
try: | try: | ||||
do_enter() | do_enter() | ||||
@@ -482,15 +480,17 @@ class trace: | |||||
if self._untraced: | if self._untraced: | ||||
# conditionally reading a compiled tensor in excluded region | # conditionally reading a compiled tensor in excluded region | ||||
# is permitted, so we have to assume every tensor might be read | # is permitted, so we have to assume every tensor might be read | ||||
for x in self._active_tensors.values(): | |||||
if x(): | |||||
info = self._tinfo[x()._mixin_handle] | |||||
for x in self._active_tensors: | |||||
strong_x = x() | |||||
if strong_x: | |||||
info = self._tinfo[strong_x._mixin_handle] | |||||
info.exported = True | info.exported = True | ||||
info.data_read = True | info.data_read = True | ||||
else: | else: | ||||
for x in self._active_tensors.values(): | |||||
if x(): | |||||
x()._dev_tensor() | |||||
for x in self._active_tensors: | |||||
strong_x = x() | |||||
if strong_x: | |||||
strong_x._dev_tensor() | |||||
def _apply_graph_options(self, graph): | def _apply_graph_options(self, graph): | ||||
@@ -520,7 +520,6 @@ class trace: | |||||
graph = self._graph = G.Graph() | graph = self._graph = G.Graph() | ||||
graph.options.async_exec_level = 0b100 | graph.options.async_exec_level = 0b100 | ||||
self._apply_graph_options(graph) | self._apply_graph_options(graph) | ||||
# graph.options.graph_opt_level = 0 | |||||
need_reset_nodes = self._need_reset_nodes = [] | need_reset_nodes = self._need_reset_nodes = [] | ||||
# links enforce ordering of I/O nodes | # links enforce ordering of I/O nodes | ||||
in_out_links = () | in_out_links = () | ||||
@@ -563,7 +562,7 @@ class trace: | |||||
if not hasattr(info, "varnode"): | if not hasattr(info, "varnode"): | ||||
assert info.external | assert info.external | ||||
if info.bound_data: | if info.bound_data: | ||||
if hasattr(info, "is_const") and info.is_const: | |||||
if getattr(info, "is_const", False): | |||||
info.varnode = graph.make_const( | info.varnode = graph.make_const( | ||||
info.bound_data.numpy(), | info.bound_data.numpy(), | ||||
info.bound_data.dtype, | info.bound_data.dtype, | ||||
@@ -635,30 +634,12 @@ class trace: | |||||
opnode.reset() | opnode.reset() | ||||
def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
if is_tracing(): | |||||
return self.__wrapped__(*args, **kwargs) | |||||
with self._setup(): | with self._setup(): | ||||
if self._capture_as_const: | if self._capture_as_const: | ||||
self._process_inputs(*args, **kwargs) | self._process_inputs(*args, **kwargs) | ||||
outputs = self.__wrapped__(*args, **kwargs) | outputs = self.__wrapped__(*args, **kwargs) | ||||
if self._capture_as_const: | if self._capture_as_const: | ||||
self._process_outputs(outputs) | self._process_outputs(outputs) | ||||
# outputs could be None | |||||
if outputs is not None: | |||||
list_outputs = outputs | |||||
if isinstance(outputs, collections.abc.Mapping): | |||||
_, list_outputs = zip(*sorted(outputs.items())) | |||||
elif not isinstance(outputs, collections.abc.Sequence): | |||||
list_outputs = (outputs,) | |||||
for o in list_outputs: | |||||
# if outputs are copied, then use the newest info in trace data structure | |||||
if o._copied: | |||||
self._active_tensors[o._mixin_handle] = TensorWeakRef(o) | |||||
if self._untraced and self._symbolic: | |||||
self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o) | |||||
return outputs | return outputs | ||||
def dump( | def dump( | ||||
@@ -9,11 +9,12 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||||
#include "./grad.h" | #include "./grad.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
#include "megbrain/imperative/backward_graph_opt.h" | #include "megbrain/imperative/backward_graph_opt.h" | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/ops/utility.h" | |||||
#include "megbrain/utils/mempool.h" | #include "megbrain/utils/mempool.h" | ||||
#include "range/v3/all.hpp" | #include "range/v3/all.hpp" | ||||
@@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) { | |||||
if (backward.output_requires_grad(i)) { | if (backward.output_requires_grad(i)) { | ||||
if (backward.output_captured(i)) { | if (backward.output_captured(i)) { | ||||
// avoid reference cycle [Tensor <-> GradFn] | // avoid reference cycle [Tensor <-> GradFn] | ||||
outputs[i] = outputs[i]->copy(); | |||||
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy()); | |||||
outputs[i] = python::apply(op, outputs[i])[0]; | |||||
} | } | ||||
// populate grad info of output tensor | // populate grad info of output tensor | ||||
auto& grad_info = outputs[i]->m_grad_info; | auto& grad_info = outputs[i]->m_grad_info; | ||||
@@ -12,6 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "./tensor.h" | #include "./tensor.h" | ||||
#include "megbrain/imperative/ops/utility.h" | |||||
#include <megbrain/utils/small_vector.h> | #include <megbrain/utils/small_vector.h> | ||||
#include <memory> | #include <memory> | ||||
@@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma | |||||
return apply(ctx); | return apply(ctx); | ||||
} | } | ||||
// Grad rule for FastpathCopy: installs a backward function that hands the
// single incoming gradient back unchanged, then runs the op via apply(ctx).
apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
// The op is only ever applied to exactly one input.
mgb_assert(ctx.nargs == 1); | |||||
// Declare one output and set its "captured" flag to false.
// NOTE(review): presumably this means the backward closure does not hold the
// forward output -- confirm against CustomBackward::Maker.
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) { | |||||
// One output was declared above, so exactly one gradient slot is expected.
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
// Result has one slot (one per input of the op).
apply_result_t ret(1); | |||||
// A null gradient pointer leaves ret[0] default-constructed; otherwise the
// incoming gradient object itself is forwarded (no new tensor is created here).
if (grad) { | |||||
ret[0] = grad->shared_from_this(); | |||||
} | |||||
return ret; | |||||
}); | |||||
return apply(ctx); | |||||
} | |||||
struct Init { | struct Init { | ||||
Init() { | Init() { | ||||
auto& reg = grad_rule_registry(); | auto& reg = grad_rule_registry(); | ||||
@@ -231,6 +246,7 @@ struct Init { | |||||
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | ||||
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | ||||
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | ||||
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||||
} | } | ||||
} _; | } _; | ||||
@@ -23,6 +23,7 @@ | |||||
#include "./common.h" | #include "./common.h" | ||||
#include "./ops.h" | #include "./ops.h" | ||||
#include "megbrain/gopt/inference.h" | #include "megbrain/gopt/inference.h" | ||||
#include "megbrain/imperative/ops/utility.h" | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
@@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) { | |||||
handles[i] = ctx.args[i]->m_handle.get(); | handles[i] = ctx.args[i]->m_handle.get(); | ||||
} | } | ||||
apply_result_t outputs; | |||||
// fast copy without really applying | |||||
if (ctx.op->same_type<FastpathCopy>()) { | |||||
mgb_assert(ctx.nargs == 1); | |||||
outputs.reserve(ctx.nargs); | |||||
outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle)); | |||||
return outputs; | |||||
} | |||||
auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); | auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); | ||||
apply_result_t outputs; | |||||
outputs.reserve(output_handles.size()); | outputs.reserve(output_handles.size()); | ||||
for (auto h : output_handles) { | for (auto h : output_handles) { | ||||
outputs.emplace_back(std::make_shared<Tensor>(h)); | outputs.emplace_back(std::make_shared<Tensor>(h)); | ||||
@@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) | |||||
#undef REGISTE_TENSORWRAPPER_FUNC | #undef REGISTE_TENSORWRAPPER_FUNC | ||||
PyObject* TensorWrapper::copied() { | |||||
return py::cast(m_tensor->m_trace_info.copied).release().ptr(); | |||||
} | |||||
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ | #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ | ||||
PyObject* TensorWrapper::member() { \ | PyObject* TensorWrapper::member() { \ | ||||
if (m_tensor->m_trace_info.member) { \ | if (m_tensor->m_trace_info.member) { \ | ||||
@@ -841,7 +845,6 @@ void init_tensor(py::module m) { | |||||
.def<&TensorWrapper::reset_varnode>("_reset_varnode") | .def<&TensorWrapper::reset_varnode>("_reset_varnode") | ||||
.def<&TensorWrapper::_use_cnt>("_use_cnt") | .def<&TensorWrapper::_use_cnt>("_use_cnt") | ||||
.def_getset<&TensorWrapper::varnode>("_varnode") | .def_getset<&TensorWrapper::varnode>("_varnode") | ||||
.def_getset<&TensorWrapper::copied>("_copied") | |||||
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") | .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") | ||||
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording") | .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording") | ||||
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | ||||
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||||
#include <variant> | #include <variant> | ||||
@@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
// assumption: python function always returns PyList | // assumption: python function always returns PyList | ||||
auto tup = py::reinterpret_borrow<py::list>(ret); | auto tup = py::reinterpret_borrow<py::list>(ret); | ||||
for (auto i = 0; i < tup.size(); i++) { | |||||
for (size_t i = 0; i < tup.size(); i++) { | |||||
auto pitem = tup[i].cast<cg::VarNode*>(); | auto pitem = tup[i].cast<cg::VarNode*>(); | ||||
outputs.emplace_back(std::make_shared<Tensor>(pitem)); | outputs.emplace_back(std::make_shared<Tensor>(pitem)); | ||||
} | } | ||||
@@ -17,7 +17,6 @@ namespace mgb::imperative::python { | |||||
struct TraceInfo { | struct TraceInfo { | ||||
int64_t mixin_handle = -1; | int64_t mixin_handle = -1; | ||||
bool recording = false; | bool recording = false; | ||||
bool copied = false; | |||||
// refer to CompiledTensorProxy in tracing.py, works from second trace step | // refer to CompiledTensorProxy in tracing.py, works from second trace step | ||||
PyObject* compiled_info = nullptr; | PyObject* compiled_info = nullptr; | ||||
@@ -35,7 +34,6 @@ struct TraceInfo { | |||||
compiled_info = that.compiled_info; | compiled_info = that.compiled_info; | ||||
Py_XINCREF(compiled_info); | Py_XINCREF(compiled_info); | ||||
copied = true; | |||||
return *this; | return *this; | ||||
} | } | ||||
@@ -18,4 +18,18 @@ namespace mgb::imperative { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | ||||
// Op-trait registration for FastpathCopy, kept in an anonymous namespace so the
// helper below is local to this translation unit.
namespace { namespace fastpathcopy { | |||||
// Var-node implementation of FastpathCopy: returns the input var nodes
// unchanged; `def` is not used.
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
return inputs; | |||||
} | |||||
// Register the trait with the helper above.
// NOTE(review): .fallback() presumably fills the remaining trait entries with
// default implementations -- confirm against the OP_TRAIT_REG definition.
OP_TRAIT_REG(FastpathCopy,FastpathCopy) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // fastpathcopy | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy); | |||||
} // namespace mgb::imperative | } // namespace mgb::imperative |
@@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | MGB_DYN_TYPE_OBJ_FINAL_DECL; | ||||
}; | }; | ||||
// Op definition for FastpathCopy. It declares no data members, so hashing and
// equality below depend only on the dynamic type.
struct FastpathCopy final : OpDefImplBase<FastpathCopy> { | |||||
FastpathCopy() = default; | |||||
// Hash is derived solely from the dynamic type info, so every instance of
// this type hashes to the same value.
size_t hash() const override { | |||||
return mgb::hash(this->dyn_typeinfo()); | |||||
} | |||||
// Two objects compare equal exactly when their dynamic type info matches.
bool is_same_st(const Hashable& rhs) const override { | |||||
return this->dyn_typeinfo() == rhs.dyn_typeinfo(); | |||||
} | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
}; | |||||
} // namespace mgb::imperative | } // namespace mgb::imperative |