|
@@ -20,30 +20,22 @@ import numpy as np
 
 from ..core._imperative_rt import GraphProfiler, common
 from ..core._imperative_rt.core2 import Tensor as RawTensor
-from ..core._imperative_rt.core2 import TensorWeakRef
-from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
 from ..core._imperative_rt.core2 import (
+    TensorWeakRef,
     apply,
     set_compiled,
-    set_symbolic,
     set_tracing,
     skip_tracing,
     unset_compiled,
-    unset_symbolic,
     unset_tracing,
 )
-from ..core._imperative_rt.ops import (
-    CollectiveComm,
-    GaussianRNG,
-    RemoteRecv,
-    RemoteSend,
-    UniformRNG,
-)
+from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
 from ..core._trace_option import set_symbolic_shape
 from ..core._wrap import device as as_device
 from ..core.ops.builtin import BackwardGraph, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
+from ..core.tensor.utils import setscalar
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -159,7 +151,6 @@ class trace:
         self._profiler = None
         self._graph_opt_level = opt_level
         self._symbolic_shape = symbolic_shape
-        self._handle2tensors = {}
         self._output_handles = set()
 
         self._reset()
@@ -195,7 +186,7 @@ class trace:
             raise TraceMismatchError("trace should end here, but more op observed")
         record = self._seq[self._pc]
         op_, ihandles, ohandles = record
-        if op != op_:
+        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
             raise TraceMismatchError("op different from last time")
         if len(ihandles) != len(args):
             raise TraceMismatchError("op input size different from last time")
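
Note on the widened guard above: const records store the literal string `"Const"` in `self._seq` rather than an `OpDef`, so replaying a real op against such a record must always be treated as a mismatch. A minimal standalone sketch of that check (names here are illustrative, not MegEngine's API):

```python
# Hypothetical stand-in for the replay check; `recorded` mimics an entry of
# trace._seq, which is either an OpDef or the string sentinel "Const".
def replay_matches(recorded, incoming):
    if isinstance(recorded, str) and recorded == "Const":
        return False  # a const record carries no op to compare against
    return recorded == incoming

op = object()
assert replay_matches(op, op)
assert not replay_matches("Const", op)
```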
|
@@ -253,9 +244,11 @@ class trace:
         self._pc += 1
         outputs = []
         for h in ohandles:
-            t = CompiledTensorProxy(h)
-            t._dev_tensor()
-            outputs += [t._CompiledTensorProxy__tensor]
+            info = self._tinfo[h]
+            y = RawTensor(info.varnode)
+            y._compiled_info = CompiledTensorProxy(h)
+            y.mixin_handle = h
+            outputs += [y]
         self._output_handles.update(ohandles)
         self._active_tensors.update([TensorWeakRef(o) for o in outputs])
         return outputs
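
The replay path above now hands back real tensors bound to graph varnodes, with a `CompiledTensorProxy` attached for lazy host access instead of eagerly calling `_dev_tensor()`. A self-contained sketch of that lazy-attachment pattern (all classes below are illustrative stand-ins, not MegEngine types):

```python
class Reader:
    """Stand-in for an output reader node in the compiled graph."""
    def get_value(self):
        return b"device-bytes"

class Proxy:
    """Materializes data on first access only, like CompiledTensorProxy."""
    def __init__(self, reader):
        self._reader, self._data = reader, None
    def dev_tensor(self):
        if self._data is None:          # first access triggers the read
            self._data = self._reader.get_value()
        return self._data

class FakeTensor:
    pass

y = FakeTensor()
y._compiled_info = Proxy(Reader())      # attached, but nothing is read yet
assert y._compiled_info.dev_tensor() == b"device-bytes"
```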
|
@@ -285,7 +278,7 @@ class trace:
             for x in inputs:
                 h = getattr(x, "mixin_handle", -1)
                 if h >= 0:
-                    x.data_read = True
+                    self._tinfo[h].data_read = True
             return
 
         ihandles = []
@@ -308,7 +301,8 @@ class trace:
             ohandles.append(h)
             info.external = False
             x.mixin_handle = h
-            self._handle2tensors[h] = x
+            x.recording = True
+            x._trace_mixin_info = info
         self._seq.append((op, tuple(ihandles), tuple(ohandles)))
         self._active_tensors.update([TensorWeakRef(o) for o in outputs])
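
This hunk, and the matching ones below, replace the central `_handle2tensors` map with attributes on the tensor itself. One plausible motivation, sketched with hypothetical names: a strong handle-to-tensor dict pins every traced tensor alive until trace exit, whereas `_trace_mixin_info` ties the bookkeeping to the tensor's own lifetime and lets flags be flipped on the shared info object directly:

```python
class Info:
    """Stand-in for TensorInfo; flags live here, shared with the trace."""
    def __init__(self):
        self.data_read = False

class T:
    pass

def record_output(x, h, info):
    x.mixin_handle = h            # the handle travels with the tensor
    x.recording = True
    x._trace_mixin_info = info    # no central dict entry to clean up later
    return x

t = record_output(T(), 0, Info())
t._trace_mixin_info.data_read = True
assert t._trace_mixin_info.data_read
```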
|
@@ -318,7 +312,7 @@ class trace:
             (x,) = outputs
             h = getattr(x, "mixin_handle", -1)
             if h >= 0:
-                x.data_read = True
+                self._tinfo[h].data_read = True
             return
 
         (x,) = outputs
@@ -331,7 +325,8 @@ class trace:
         info.bound_data = x
         info.is_const = True
         x.mixin_handle = h
-        self._handle2tensors[h] = x
+        x.recording = True
+        x._trace_mixin_info = info
         self._seq.append(("Const", tuple(), tuple(ohandles)))
 
     def _set_active(self, active: bool):
@@ -346,7 +341,6 @@ class trace:
 
     def _init_trace(self, symbolic: bool):
         if symbolic:
-            set_symbolic()
             self._lazy_eval_graph = G.Graph()
             self._apply_graph_options(self._lazy_eval_graph)
             self._lazy_eval_links = ()
@@ -383,8 +377,6 @@ class trace:
             if self._untraced:
                 self._init_trace(self._symbolic)
             else:
-                # disable symbolic mode
-                unset_symbolic()
                 set_compiled()
                 if self._graph is None:
                     self._compile()
@@ -394,18 +386,15 @@ class trace:
             escaped_tensors = self._take_escaped_tensors()
             if self._untraced:
                 for x in escaped_tensors:
-                    info = self._tinfo[x().mixin_handle]
-                    x().data_read = True
-                    x().mixin_handle = -1
+                    if x():
+                        info = self._tinfo[x().mixin_handle]
+                        info.data_read = True
+                        x().mixin_handle = -1
+                        x().recording = False
                 if self._inputs_to_restore:
                     for x in self._inputs_to_restore:
                         x.mixin_handle = -1
-                for h, x in list(self._handle2tensors.items()):
-                    info = self._tinfo[h]
-                    info.data_read = x.data_read
-                    info.shape_read = x.shape_read
-                    info.value_read = x.value_read
-                    del self._handle2tensors[h]
+                        x.recording = False
                 if self._symbolic and (
                     self._lazy_eval_tensors or self._lazy_eval_links
                 ):
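
The new `if x():` guard matters because `_active_tensors` holds weak references: an escaped tensor may already have been collected by the time the trace finalizes, in which case dereferencing the ref returns `None`. The same pattern in isolation:

```python
import weakref

class T:
    pass

live, dead = T(), T()
refs = [weakref.ref(live), weakref.ref(dead)]
del dead                 # second referent is gone; its ref now yields None

for r in refs:
    if r():              # guard before touching attributes, as in the diff
        r().recording = False

assert live.recording is False
```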
|
@@ -437,7 +426,6 @@ class trace:
             self._set_active(False)
             set_symbolic_shape(self._save_symbolic_shape)
             unset_compiled()
-            unset_symbolic()
             unset_tracing()
 
         def do_exit():
@@ -449,6 +437,7 @@ class trace:
                 if x() is not None:
                     x()._dev_tensor()
                     x().mixin_handle = -1
+                    x().recording = False
 
         try:
             do_enter()
@@ -473,7 +462,8 @@ class trace:
             for x in self._active_tensors:
                 info = self._tinfo[x().mixin_handle]
                 info.exported = True
-                x().data_read = True
+                info.data_read = True
+                x()._dev_tensor()
 
     def _apply_graph_options(self, graph):
@@ -528,6 +518,7 @@ class trace:
                 info.varnode = opnode.outputs[0]
                 in_out_links += opnode.outputs[1:]
 
+        cnt_data, cnt_value, cnt_shape = 0, 0, 0
         for op, ihandles, ohandles in self._seq:
             if isinstance(op, str) and op == "Const":
                 assert len(ihandles) == 0
@@ -603,13 +594,16 @@ class trace:
                 # Shape can be obtained from data so doesn't need its own
                 # output node. On the other hand, value is read separately
                 # to leverage eager h2d copy
+                cnt_data += 1
                 info.shape_read = False
                 opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                 add_reader(opnode)
             if info.value_read:
+                cnt_value += 1
                 opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                 add_reader(opnode)
             if info.shape_read:
+                cnt_shape += 1
                 opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                 add_reader(opnode)
@@ -804,7 +798,8 @@ class trace:
             info.dtype = x.dtype
             info.shape = x.numpy().shape
             x.mixin_handle = h
-            self._handle2tensors[h] = x
+            x.recording = True
+            x._trace_mixin_info = info
             self._inputs_to_restore.append(x)
             return h
@@ -940,7 +935,6 @@ class CompiledTensorProxy:
         self.__shape = None
         self.__data = None
         self.__value = None
-        self.__tensor = make_empty_tensor()
 
     @property
     def dtype(self):
@@ -958,7 +952,7 @@ class CompiledTensorProxy:
             if self.__info.shape_read:
                 self.__shape = self.__info.shape_reader.get_value().shape
             elif self.__info.data_read:
-                self.__shape = self.__info._dev_tensor().shape
+                self.__shape = self._dev_tensor().shape
             else:
                 raise TraceMismatchError("shape of this tensor is not read in trace")
         return self.__shape
@@ -980,25 +974,14 @@ class CompiledTensorProxy:
         if not self.__info.data_read:
             raise TraceMismatchError("raw data of this tensor is not read in trace")
         self.__data = self.__info.data_reader.get_value()
-        self.__tensor._reset(RawTensor(self.__data))
-        self.__tensor.mixin_handle = self.__handle
         return self.__data
 
-    def _drop(self):
-        return
-
-    def _swap_in(self):
-        return
-
-    def _swap_out(self):
-        return
-
     def __del__(self):
-        if self.__tensor.shape_read and self.__shape is not None:
+        if self.__info.shape_read and self.__shape is not None:
             self.__info.shape_reader.drop_value()
-        if self.__tensor.value_read and self.__value is not None:
+        if self.__info.value_read and self.__value is not None:
             self.__info.value_reader.drop_value()
-        if self.__tensor.data_read and self.__data is not None:
+        if self.__info.data_read and self.__data is not None:
             self.__info.data_reader.drop_value()
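
With `self.__tensor` gone, `__del__` consults the shared info flags to decide which cached reader results to release. A minimal sketch of that release-on-destruction pattern (stand-in classes, not MegEngine's readers):

```python
class Reader:
    def __init__(self):
        self.cached = "value"
    def drop_value(self):
        self.cached = None

class Proxy:
    def __init__(self, info, reader):
        self._info, self._reader, self._value = info, reader, None
    def read(self):
        self._value = self._reader.cached
        return self._value
    def __del__(self):
        # release only what was declared readable and actually produced
        if self._info["data_read"] and self._value is not None:
            self._reader.drop_value()

reader = Reader()
p = Proxy({"data_read": True}, reader)
p.read()
del p                    # CPython runs __del__ here, dropping the cache
assert reader.cached is None
```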
@@ -1054,6 +1037,8 @@ def apply_const_symbolic_mode(value, dtype, device):
     # don't need to unset tracing
     # because varnode construction will ignore tracing flag
     ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
+    if np.array(value).ndim == 0:
+        setscalar(ret)
     active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
     return (ret,)
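
The added `np.array(value).ndim == 0` test is how a scalar constant is detected before `setscalar` marks the result: wrapping a Python or NumPy scalar yields a zero-dimensional array, while anything with a shape has `ndim >= 1`. For example:

```python
import numpy as np

assert np.array(3.14).ndim == 0        # scalar constant -> setscalar applies
assert np.array([1.0, 2.0]).ndim == 1  # vector constant -> left unchanged
```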
@@ -1084,7 +1069,6 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
         return active_trace._apply_const(value, dtype, device)
 
 
-# this hook injects TraceMixin
 def apply_with_tracing(op: OpDef, *args: RawTensor):
     if active_trace._symbolic:
         outputs = apply_symbolic_mode(op, *args)