GitOrigin-RevId: fd1265c661
release-1.4
@@ -170,9 +170,9 @@ class trace: | |||
self._graph = None | |||
self._need_reset_nodes = None | |||
self._lazy_eval_graph = None | |||
self._lazy_eval_tensors = {} | |||
self._lazy_eval_tensors = set() | |||
self._lazy_eval_links = None | |||
self._active_tensors = {} | |||
self._active_tensors = set() | |||
self._tensor_remaps = None | |||
self._inputs_to_restore = None | |||
self._arg_bindings = None | |||
@@ -258,7 +258,7 @@ class trace: | |||
y._compiled_info = CompiledTensorProxy(h) | |||
y._mixin_handle = h | |||
outputs += [y] | |||
self._active_tensors[h] = TensorWeakRef(y) | |||
self._active_tensors.add(TensorWeakRef(y)) | |||
self._output_handles.update(ohandles) | |||
return outputs | |||
@@ -318,9 +318,9 @@ class trace: | |||
x._mixin_handle = h | |||
x._recording = True | |||
x._trace_mixin_info = info | |||
self._active_tensors[h] = TensorWeakRef(x) | |||
self._active_tensors.add(TensorWeakRef(x)) | |||
if self._symbolic: | |||
self._lazy_eval_tensors[h] = TensorWeakRef(x) | |||
self._lazy_eval_tensors.add(TensorWeakRef(x)) | |||
self._seq.append((op, tuple(ihandles), tuple(ohandles))) | |||
@@ -345,7 +345,7 @@ class trace: | |||
x._recording = True | |||
x._trace_mixin_info = info | |||
if self._symbolic: | |||
self._lazy_eval_tensors[h] = TensorWeakRef(x) | |||
self._lazy_eval_tensors.add(TensorWeakRef(x)) | |||
self._seq.append(("Const", tuple(), tuple(ohandles))) | |||
def _set_active(self, active: bool): | |||
@@ -365,17 +365,14 @@ class trace: | |||
self._lazy_eval_links = () | |||
def _take_escaped_tensors(self): | |||
escaped_tensors = tuple( | |||
filter(lambda x: x() is not None, self._active_tensors.values()) | |||
) | |||
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) | |||
self._active_tensors.clear() | |||
return escaped_tensors | |||
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | |||
lazy_eval_tensors = list( | |||
filter(lambda x: x() is not None, lazy_eval_tensors.values()) | |||
) | |||
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | |||
lazy_eval_tensors = [x() for x in lazy_eval_tensors] | |||
lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None] | |||
readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors] | |||
self._apply_graph_options(lazy_eval_graph) | |||
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level | |||
lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) | |||
@@ -383,8 +380,8 @@ class trace: | |||
lazy_eval_graph() | |||
for r, x in zip(readers, lazy_eval_tensors): | |||
# get values from lazy_eval_graph and assign to lazy_eval tensor | |||
x()._handle = RawTensor(r.op.get_value())._handle | |||
x()._reset_varnode() | |||
x._handle = RawTensor(r.op.get_value())._handle | |||
x._reset_varnode() | |||
@contextlib.contextmanager | |||
def _setup(self): | |||
@@ -454,13 +451,14 @@ class trace: | |||
raise TraceMismatchError("premature end") | |||
if not self._symbolic or not self._untraced: | |||
# reset output tensors | |||
for x in self._active_tensors.values(): | |||
if x() is not None: | |||
x()._dev_tensor() | |||
x()._reset_varnode() | |||
x()._mixin_handle = -1 | |||
x()._recording = False | |||
x()._trace_mixin_info = None | |||
for x in self._active_tensors.copy(): | |||
strong_x = x() | |||
if strong_x is not None: | |||
strong_x._dev_tensor() | |||
strong_x._reset_varnode() | |||
strong_x._mixin_handle = -1 | |||
strong_x._recording = False | |||
strong_x._trace_mixin_info = None | |||
try: | |||
do_enter() | |||
@@ -482,15 +480,17 @@ class trace: | |||
if self._untraced: | |||
# conditionally reading a compiled tensor in excluded region | |||
# is permitted, so we have to assume every tensor might be read | |||
for x in self._active_tensors.values(): | |||
if x(): | |||
info = self._tinfo[x()._mixin_handle] | |||
for x in self._active_tensors: | |||
strong_x = x() | |||
if strong_x: | |||
info = self._tinfo[strong_x._mixin_handle] | |||
info.exported = True | |||
info.data_read = True | |||
else: | |||
for x in self._active_tensors.values(): | |||
if x(): | |||
x()._dev_tensor() | |||
for x in self._active_tensors: | |||
strong_x = x() | |||
if strong_x: | |||
strong_x._dev_tensor() | |||
def _apply_graph_options(self, graph): | |||
@@ -520,7 +520,6 @@ class trace: | |||
graph = self._graph = G.Graph() | |||
graph.options.async_exec_level = 0b100 | |||
self._apply_graph_options(graph) | |||
# graph.options.graph_opt_level = 0 | |||
need_reset_nodes = self._need_reset_nodes = [] | |||
# links enforce ordering of I/O nodes | |||
in_out_links = () | |||
@@ -563,7 +562,7 @@ class trace: | |||
if not hasattr(info, "varnode"): | |||
assert info.external | |||
if info.bound_data: | |||
if hasattr(info, "is_const") and info.is_const: | |||
if getattr(info, "is_const", False): | |||
info.varnode = graph.make_const( | |||
info.bound_data.numpy(), | |||
info.bound_data.dtype, | |||
@@ -635,30 +634,12 @@ class trace: | |||
opnode.reset() | |||
def __call__(self, *args, **kwargs): | |||
if is_tracing(): | |||
return self.__wrapped__(*args, **kwargs) | |||
with self._setup(): | |||
if self._capture_as_const: | |||
self._process_inputs(*args, **kwargs) | |||
outputs = self.__wrapped__(*args, **kwargs) | |||
if self._capture_as_const: | |||
self._process_outputs(outputs) | |||
# outputs could be None | |||
if outputs is not None: | |||
list_outputs = outputs | |||
if isinstance(outputs, collections.abc.Mapping): | |||
_, list_outputs = zip(*sorted(outputs.items())) | |||
elif not isinstance(outputs, collections.abc.Sequence): | |||
list_outputs = (outputs,) | |||
for o in list_outputs: | |||
# if outputs are copied, then use the newest info in trace data structure | |||
if o._copied: | |||
self._active_tensors[o._mixin_handle] = TensorWeakRef(o) | |||
if self._untraced and self._symbolic: | |||
self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o) | |||
return outputs | |||
def dump( | |||
@@ -9,11 +9,12 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||
#include "./grad.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
#include "megbrain/imperative/backward_graph_opt.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include "megbrain/utils/mempool.h" | |||
#include "range/v3/all.hpp" | |||
@@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) { | |||
if (backward.output_requires_grad(i)) { | |||
if (backward.output_captured(i)) { | |||
// avoid reference cycle [Tensor <-> GradFn] | |||
outputs[i] = outputs[i]->copy(); | |||
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy()); | |||
outputs[i] = python::apply(op, outputs[i])[0]; | |||
} | |||
// populate grad info of output tensor | |||
auto& grad_info = outputs[i]->m_grad_info; | |||
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "./tensor.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include <megbrain/utils/small_vector.h> | |||
#include <memory> | |||
@@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma | |||
return apply(ctx); | |||
} | |||
apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
mgb_assert(ctx.nargs == 1); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
if (grad) { | |||
ret[0] = grad->shared_from_this(); | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
} | |||
struct Init { | |||
Init() { | |||
auto& reg = grad_rule_registry(); | |||
@@ -231,6 +246,7 @@ struct Init { | |||
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | |||
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | |||
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
} | |||
} _; | |||
@@ -23,6 +23,7 @@ | |||
#include "./common.h" | |||
#include "./ops.h" | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
namespace py = pybind11; | |||
@@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) { | |||
handles[i] = ctx.args[i]->m_handle.get(); | |||
} | |||
apply_result_t outputs; | |||
// fast copy without really applying | |||
if (ctx.op->same_type<FastpathCopy>()) { | |||
mgb_assert(ctx.nargs == 1); | |||
outputs.reserve(ctx.nargs); | |||
outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle)); | |||
return outputs; | |||
} | |||
auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); | |||
apply_result_t outputs; | |||
outputs.reserve(output_handles.size()); | |||
for (auto h : output_handles) { | |||
outputs.emplace_back(std::make_shared<Tensor>(h)); | |||
@@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) | |||
#undef REGISTE_TENSORWRAPPER_FUNC | |||
PyObject* TensorWrapper::copied() { | |||
return py::cast(m_tensor->m_trace_info.copied).release().ptr(); | |||
} | |||
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ | |||
PyObject* TensorWrapper::member() { \ | |||
if (m_tensor->m_trace_info.member) { \ | |||
@@ -841,7 +845,6 @@ void init_tensor(py::module m) { | |||
.def<&TensorWrapper::reset_varnode>("_reset_varnode") | |||
.def<&TensorWrapper::_use_cnt>("_use_cnt") | |||
.def_getset<&TensorWrapper::varnode>("_varnode") | |||
.def_getset<&TensorWrapper::copied>("_copied") | |||
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") | |||
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording") | |||
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | |||
@@ -10,6 +10,7 @@ | |||
*/ | |||
#pragma once | |||
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||
#include <variant> | |||
@@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||
// assumption: python function always returns PyList | |||
auto tup = py::reinterpret_borrow<py::list>(ret); | |||
for (auto i = 0; i < tup.size(); i++) { | |||
for (size_t i = 0; i < tup.size(); i++) { | |||
auto pitem = tup[i].cast<cg::VarNode*>(); | |||
outputs.emplace_back(std::make_shared<Tensor>(pitem)); | |||
} | |||
@@ -17,7 +17,6 @@ namespace mgb::imperative::python { | |||
struct TraceInfo { | |||
int64_t mixin_handle = -1; | |||
bool recording = false; | |||
bool copied = false; | |||
// refer to CompiledTensorProxy in tracing.py, works from second trace step | |||
PyObject* compiled_info = nullptr; | |||
@@ -35,7 +34,6 @@ struct TraceInfo { | |||
compiled_info = that.compiled_info; | |||
Py_XINCREF(compiled_info); | |||
copied = true; | |||
return *this; | |||
} | |||
@@ -18,4 +18,18 @@ namespace mgb::imperative { | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | |||
namespace { namespace fastpathcopy { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
return inputs; | |||
} | |||
OP_TRAIT_REG(FastpathCopy,FastpathCopy) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // fastpathcopy | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy); | |||
} // namespace mgb::imperative |
@@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
}; | |||
struct FastpathCopy final : OpDefImplBase<FastpathCopy> { | |||
FastpathCopy() = default; | |||
size_t hash() const override { | |||
return mgb::hash(this->dyn_typeinfo()); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return this->dyn_typeinfo() == rhs.dyn_typeinfo(); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
}; | |||
} // namespace mgb::imperative |