|
|
@@ -11,6 +11,8 @@ import threading |
|
|
|
import weakref |
|
|
|
from concurrent.futures import Future, ThreadPoolExecutor |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from .. import _imperative_rt |
|
|
|
from .._imperative_rt.ops import BackwardGraph |
|
|
|
from .._wrap import device as as_device |
|
|
@@ -32,6 +34,8 @@ class Graph(_imperative_rt.ComputingGraph): |
|
|
|
wrapper, cache = VarNode, self._var_cache |
|
|
|
elif type(obj) is _imperative_rt.OperatorNode: |
|
|
|
wrapper, cache = OpNode, self._op_cache |
|
|
|
else: |
|
|
|
raise TypeError(type(obj)) |
|
|
|
if obj not in cache: |
|
|
|
cache[obj] = wrapper(obj) |
|
|
|
return cache[obj] |
|
|
@@ -62,6 +66,11 @@ class Graph(_imperative_rt.ComputingGraph): |
|
|
|
assert dtype is None and device is None |
|
|
|
return self._wrap(_imperative_rt.make_shared(self, data)) |
|
|
|
else: |
|
|
|
data = np.asarray(data, dtype=dtype) |
|
|
|
if data.dtype == np.float64: |
|
|
|
data = data.astype(np.float32) |
|
|
|
elif data.dtype == np.int64: |
|
|
|
data = data.astype(np.int32) |
|
|
|
device = as_device(device).to_c() |
|
|
|
return self._wrap(_imperative_rt.make_const(self, data, device, dtype)) |
|
|
|
|
|
|
|