|
|
@@ -234,20 +234,21 @@ class trace: |
|
|
|
) |
|
|
|
info.data_setter.set_value(x._dev_tensor()) |
|
|
|
else: |
|
|
|
pass |
|
|
|
# if x.__class__ is not CompiledTensorProxy: |
|
|
|
# if x not in self._tensor_remaps: |
|
|
|
# raise TraceMismatchError( |
|
|
|
# "unexpected capture: trying to use an external tensor as " |
|
|
|
# "input, but that input was an internal tensor last time" |
|
|
|
# ) |
|
|
|
# else: |
|
|
|
# x = self._tensor_remaps[x] |
|
|
|
# if x._CompiledTensorProxy__handle != h: |
|
|
|
# raise TraceMismatchError( |
|
|
|
# "mis-wiring: input edge to an data flow " |
|
|
|
# "graph node is different from last time" |
|
|
|
# ) |
|
|
|
if x.mixin_handle == -1: |
|
|
|
if x._handle not in self._tensor_remaps: |
|
|
|
raise TraceMismatchError( |
|
|
|
"unexpected capture: trying to use an external tensor as " |
|
|
|
"input, but that input was an internal tensor last time" |
|
|
|
) |
|
|
|
else: |
|
|
|
x.mixin_handle = self._tensor_remaps[ |
|
|
|
x._handle |
|
|
|
]._CompiledTensorProxy__handle |
|
|
|
if x.mixin_handle != h: |
|
|
|
raise TraceMismatchError( |
|
|
|
"mis-wiring: input edge to an data flow " |
|
|
|
"graph node is different from last time" |
|
|
|
) |
|
|
|
|
|
|
|
self._pc += 1 |
|
|
|
outputs = [] |
|
|
@@ -268,14 +269,11 @@ class trace: |
|
|
|
op_, ihandles, ohandles = record |
|
|
|
assert isinstance(op_, str) and op_ == "Const" |
|
|
|
|
|
|
|
# TODO : assert on const value |
|
|
|
# eq = value == self._tinfo[ohandles[0]].bound_data.numpy() |
|
|
|
# if not isinstance(eq, bool): |
|
|
|
# eq = all(eq) |
|
|
|
# if not eq: |
|
|
|
# raise TraceMismatchError( |
|
|
|
# "const tensor violated: got a different tensor this time" |
|
|
|
# ) |
|
|
|
eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy()) |
|
|
|
if not eq: |
|
|
|
raise TraceMismatchError( |
|
|
|
"const tensor violated: got a different tensor this time" |
|
|
|
) |
|
|
|
|
|
|
|
self._pc += 1 |
|
|
|
(h,) = ohandles |
|
|
@@ -750,7 +748,6 @@ class trace: |
|
|
|
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k |
|
|
|
) |
|
|
|
|
|
|
|
set_tracing() |
|
|
|
for op, ihandles, ohandles in self._seq: |
|
|
|
if isinstance(op, str) and op == "Const": |
|
|
|
assert len(ihandles) == 0 |
|
|
@@ -776,7 +773,6 @@ class trace: |
|
|
|
ovars = G.apply_normal_varnode(op, *ivars) |
|
|
|
assert len(ovars) == len(ohandles) |
|
|
|
h2v.update(zip(ohandles, ovars)) |
|
|
|
unset_tracing() |
|
|
|
|
|
|
|
dest_vars = [] |
|
|
|
for i, h in enumerate(self._output_bindings): |
|
|
@@ -843,7 +839,7 @@ class trace: |
|
|
|
if x.device != info.device: |
|
|
|
raise TypeError("args[%d].device different from last time" % i) |
|
|
|
info.data_setter.set_value(x._dev_tensor()) |
|
|
|
self._tensor_remaps[x] = CompiledTensorProxy(h) |
|
|
|
self._tensor_remaps[x._handle] = CompiledTensorProxy(h) |
|
|
|
|
|
|
|
kwargs_tensors = {} |
|
|
|
for k, x in kwargs.items(): |
|
|
@@ -870,7 +866,7 @@ class trace: |
|
|
|
if x.device != info.device: |
|
|
|
raise TypeError("kwargs[%s].device different from last time" % k) |
|
|
|
info.data_setter.set_value(x._dev_tensor()) |
|
|
|
self._tensor_remaps[x] = CompiledTensorProxy(h) |
|
|
|
self._tensor_remaps[x._handle] = CompiledTensorProxy(h) |
|
|
|
|
|
|
|
def _process_outputs(self, outputs): |
|
|
|
output_names = None |
|
|
@@ -1000,8 +996,8 @@ class CompiledTensorProxy: |
|
|
|
def __del__(self): |
|
|
|
if self.__tensor.shape_read and self.__shape is not None: |
|
|
|
self.__info.shape_reader.drop_value() |
|
|
|
# if self.__tensor.value_read and self.__value is not None: |
|
|
|
# self.__info.value_reader.drop_value() |
|
|
|
if self.__tensor.value_read and self.__value is not None: |
|
|
|
self.__info.value_reader.drop_value() |
|
|
|
if self.__tensor.data_read and self.__data is not None: |
|
|
|
self.__info.data_reader.drop_value() |
|
|
|
|
|
|
@@ -1047,7 +1043,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): |
|
|
|
outputs = [RawTensor(o) for o in ovars] |
|
|
|
|
|
|
|
if require_links: |
|
|
|
active_trace._lazy_eval_links = (outputs[0]._varnode,) |
|
|
|
active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),) |
|
|
|
|
|
|
|
active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs]) |
|
|
|
return outputs |
|
|
|