|
@@ -293,7 +293,9 @@ class trace: |
|
|
h = getattr(x, "_mixin_handle", -1) |
|
|
h = getattr(x, "_mixin_handle", -1) |
|
|
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): |
|
|
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): |
|
|
h, info = self._new_handle() |
|
|
h, info = self._new_handle() |
|
|
name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name |
|
|
|
|
|
|
|
|
name = ( |
|
|
|
|
|
auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name) |
|
|
|
|
|
) |
|
|
info.name = name |
|
|
info.name = name |
|
|
info.external = True |
|
|
info.external = True |
|
|
info.device = x.device |
|
|
info.device = x.device |
|
@@ -1123,11 +1125,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): |
|
|
return outputs |
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_const_symbolic_mode(value, dtype, device): |
|
|
|
|
|
|
|
|
def apply_const_symbolic_mode(value, dtype, device, name): |
|
|
graph = active_trace._lazy_eval_graph |
|
|
graph = active_trace._lazy_eval_graph |
|
|
# don't need to unset tracing |
|
|
# don't need to unset tracing |
|
|
# because varnode construction will ignore tracing flag |
|
|
# because varnode construction will ignore tracing flag |
|
|
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) |
|
|
|
|
|
|
|
|
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name)) |
|
|
if np.array(value).ndim == 0: |
|
|
if np.array(value).ndim == 0: |
|
|
setscalar(ret) |
|
|
setscalar(ret) |
|
|
return (ret,) |
|
|
return (ret,) |
|
@@ -1175,7 +1177,7 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): |
|
|
|
|
|
|
|
|
def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name): |
|
|
def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name): |
|
|
if active_trace._symbolic: |
|
|
if active_trace._symbolic: |
|
|
outputs = apply_const_symbolic_mode(value, dtype, device) |
|
|
|
|
|
|
|
|
outputs = apply_const_symbolic_mode(value, dtype, device, name) |
|
|
else: |
|
|
else: |
|
|
unset_tracing() |
|
|
unset_tracing() |
|
|
outputs = (RawTensor(value, dtype, device, False, name),) |
|
|
outputs = (RawTensor(value, dtype, device, False, name),) |
|
|