@@ -20,10 +20,10 @@ import numpy as np
 from ..core._imperative_rt import GraphProfiler, common, put
 from ..core._imperative_rt.core2 import Tensor as RawTensor
-from ..core._imperative_rt.core2 import TensorWeakRef
 from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
 from ..core._imperative_rt.core2 import (
+    TensorWeakRef,
     apply,
     call_level,
     set_compiled,
     set_symbolic,
     set_tracing,
@@ -86,6 +86,9 @@ class TensorInfo:
     __slots__ = (
         # collected attributes
         "external",
+        "data_read",
+        "shape_read",
+        "value_read",
         "exported",
         "device",
         "dtype",
@@ -102,6 +105,9 @@ class TensorInfo:
 
     def __init__(self):
         self.exported = None
+        self.data_read = None
+        self.shape_read = None
+        self.value_read = None
         self.bound_data = None
 
         self.data_setter = None
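
NOTE: TensorInfo keeps its per-handle bookkeeping in a __slots__ class whose
flags are all reset to None in __init__. A minimal sketch of the same pattern
(the names here are illustrative, not the MegEngine API):

    class Record:
        # __slots__ suppresses the per-instance __dict__: each Record is
        # smaller, and assigning a misspelled attribute raises AttributeError.
        __slots__ = ("data_read", "shape_read", "value_read")

        def __init__(self):
            # Initialize every slot; reading an unassigned slot raises
            # AttributeError instead of silently returning None.
            self.data_read = None
            self.shape_read = None
            self.value_read = None

    r = Record()
    r.data_read = True   # fine
    # r.data_raed = True  # AttributeError: typos cannot create new attributes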
|
|
@@ -154,7 +160,7 @@ class trace:
         self._graph_opt_level = opt_level
         self._symbolic_shape = symbolic_shape
         self._handle2tensors = {}
-        self._handle2compiledtensors = {}
+        self._output_handles = set()
 
         self._reset()
 
@@ -244,11 +250,12 @@ class trace:
         # )
 
         self._pc += 1
+        outputs = []
         for h in ohandles:
             t = CompiledTensorProxy(h)
             t._dev_tensor()
-            self._handle2compiledtensors[h] = t
-        outputs = [self._handle2tensors[h] for h in ohandles]
+            outputs += [t._CompiledTensorProxy__tensor]
+        self._output_handles.update(ohandles)
         self._active_tensors.update([TensorWeakRef(o) for o in outputs])
         return outputs
 
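
NOTE: outputs += [t._CompiledTensorProxy__tensor] reaches the proxy's private
__tensor attribute from outside the class, which works because of Python name
mangling. A quick illustration with a toy class (not the real proxy):

    class Proxy:
        def __init__(self):
            self.__tensor = "payload"   # stored under the name _Proxy__tensor

    p = Proxy()
    assert p._Proxy__tensor == "payload"   # the mangled name is reachable
    # p.__tensor raises AttributeError: mangling is applied only to code
    # inside the class body, so external callers must spell the mangled name.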
|
|
|
|
|
|
@@ -347,11 +354,12 @@ class trace:
         self._lazy_eval_links = ()
 
     def _take_escaped_tensors(self):
-        escaped_tensors = tuple(self._active_tensors)
+        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
         self._active_tensors.clear()
         return escaped_tensors
 
     def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
+        lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors))
         readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
         self._apply_graph_options(lazy_eval_graph)
         # FIXME
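
NOTE: _active_tensors and lazy_eval_tensors hold weak references, so an entry
whose tensor has already been collected dereferences to None; the added
filter(...) calls drop those dead entries before use. The stdlib analogue:

    import weakref

    class Obj:   # weakref-able stand-in for a tensor
        pass

    live = Obj()
    dead = Obj()
    refs = [weakref.ref(live), weakref.ref(dead)]
    del dead   # on CPython the referent is collected immediately

    # Calling a weakref returns the referent, or None once it is gone.
    alive = [r for r in refs if r() is not None]
    assert len(alive) == 1 and alive[0]() is live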
|
|
@@ -393,6 +401,12 @@ class trace:
         if self._inputs_to_restore:
             for x in self._inputs_to_restore:
                 x.mixin_handle = -1
+        for h, x in list(self._handle2tensors.items()):
+            info = self._tinfo[h]
+            info.data_read = x.data_read
+            info.shape_read = x.shape_read
+            info.value_read = x.value_read
+            del self._handle2tensors[h]
         if self._symbolic and (
             self._lazy_eval_tensors or self._lazy_eval_links
         ):
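
NOTE: the new loop snapshots the per-tensor read flags into TensorInfo and
drains _handle2tensors so the traced tensors can be released. Iterating over
list(d.items()) is what makes the del inside the loop safe:

    d = {1: "a", 2: "b"}
    for k, v in list(d.items()):   # list() copies the items up front...
        del d[k]                   # ...so mutating d mid-loop cannot break it
    assert d == {}
    # Iterating d.items() directly while deleting raises RuntimeError:
    # "dictionary changed size during iteration".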
|
|
@@ -433,8 +447,9 @@ class trace:
             raise TraceMismatchError("premature end")
         if not self._symbolic or not self._untraced:
             for x in self._active_tensors:
-                x()._dev_tensor()
-                x().mixin_handle = -1
+                if x() is not None:
+                    x()._dev_tensor()
+                    x().mixin_handle = -1
 
         try:
             do_enter()
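
NOTE: the added guard avoids dereferencing dead weak references, but each x()
call dereferences anew, and nothing pins the tensor between the check and the
use. Binding the referent once is the airtight idiom; a possible follow-up:

    for x in self._active_tensors:
        t = x()             # dereference exactly once and keep a strong ref
        if t is not None:   # the referent cannot be collected while t lives
            t._dev_tensor()
            t.mixin_handle = -1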
|
|
@@ -581,8 +596,7 @@ class trace:
                 readers.append(opnode.outputs[0])
                 in_out_links = opnode.outputs
 
-                x = self._handle2tensors[h]
-                if x.data_read:
+                if info.data_read:
                     # Shape can be obtained from data so doesn't need its own
                     # output node. On the other hand, value is read separately
                     # to leverage eager h2d copy
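
NOTE: as the comment says, a data output node subsumes the shape read (the
shape comes for free with the data), while value reads stay separate to keep
the eager copy. A hypothetical helper to make the flag normalization concrete
(normalize_reads is not in the patch; it just restates the rule):

    def normalize_reads(info):
        # Full data is being read back anyway, so a dedicated shape reader
        # would be redundant: the shape is taken from the data tensor itself.
        if info.data_read:
            info.shape_read = False
        return info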
|
|
@@ -890,7 +904,7 @@ class trace:
                 self._output_bindings.append(h)
             else:
                 h = x.mixin_handle
-                if h not in self._handle2compiledtensors:
+                if h not in self._output_handles:
                     raise RuntimeError("output is not computed from inputs")
                 if h != self._output_bindings[i]:
                     raise TraceMismatchError(
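
NOTE: these checks enforce the replay contract: outputs observed on later runs
must match what the first, untraced run recorded, otherwise the trace raises.
A stripped-down model of that record-then-validate scheme (illustrative only):

    class MiniTrace:
        def __init__(self):
            self.untraced = True
            self.recorded = []   # op sequence captured on the first run
            self.pc = 0          # program counter used during replay

        def apply(self, op):
            if self.untraced:
                self.recorded.append(op)   # first run: record
                return
            if self.pc >= len(self.recorded) or self.recorded[self.pc] != op:
                raise RuntimeError("trace mismatch")   # divergent replay
            self.pc += 1   # replay: validate order, then advance

    t = MiniTrace()
    t.apply("add"); t.apply("mul")
    t.untraced = False
    t.apply("add"); t.apply("mul")   # ok: matches the recording exactly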
|
|
@@ -927,8 +941,7 @@ class CompiledTensorProxy:
         self.__shape = None
         self.__data = None
         self.__value = None
-        self.__tensor = active_trace._handle2tensors[handle]
-        self.__tensor.mixin_handle = handle
+        self.__tensor = make_empty_tensor()
 
     @property
     def dtype(self):
@@ -943,19 +956,19 @@ class CompiledTensorProxy:
         if self._isscalar:
             return ()
         if self.__shape is None:
-            if self.__tensor.shape_read:
+            if self.__info.shape_read:
                 self.__shape = self.__info.shape_reader.get_value().shape
-            elif self.__tensor.data_read:
-                self.__shape = self.__tensor._dev_tensor().shape
+            elif self.__info.data_read:
+                self.__shape = self._dev_tensor().shape
             else:
                 raise TraceMismatchError("shape of this tensor is not read in trace")
         return self.__shape
 
     def numpy(self):
         if self.__value is None:
-            if self.__tensor.value_read:
+            if self.__info.value_read:
                 self.__value = self.__info.value_reader.get_value()
-            elif self.__tensor.data_read:
+            elif self.__info.data_read:
                 self.__value = self._dev_tensor().numpy()
             else:
                 raise TraceMismatchError("value of this tensor is not read in trace")
|
|
@@ -965,7 +978,7 @@ class CompiledTensorProxy:
     def _dev_tensor(self):
         if self.__data is None:
-            if not self.__tensor.data_read:
+            if not self.__info.data_read:
                 raise TraceMismatchError("raw data of this tensor is not read in trace")
             self.__data = self.__info.data_reader.get_value()
             self.__tensor._reset(RawTensor(self.__data))
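
NOTE: the proxy starts from make_empty_tensor() and only fills the tensor via
_reset(RawTensor(...)) once the data reader has produced a value, so anyone
already holding the tensor sees it populate in place. A toy model of that
fill-in-place pattern (Placeholder stands in for the real tensor type):

    class Placeholder:
        def __init__(self):
            self._payload = None      # starts empty, like make_empty_tensor()

        def _reset(self, payload):
            self._payload = payload   # fill in place: all holders see it

    handed_out = Placeholder()        # given to callers before data exists
    alias = handed_out                # callers keep their own references
    handed_out._reset([1, 2, 3])      # data arrives during execution
    assert alias._payload == [1, 2, 3]   # every alias observes the fill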
|
|
|