|
|
@@ -11,6 +11,8 @@ |
|
|
|
import collections |
|
|
|
|
|
|
|
from .core import Tensor as _Tensor |
|
|
|
from .core.ops.builtin import Copy |
|
|
|
from .core.tensor.core import apply |
|
|
|
from .device import get_default_device |
|
|
|
|
|
|
|
|
|
|
@@ -30,6 +32,9 @@ class Tensor(_Tensor): |
|
|
|
def reset_zero(self): |
|
|
|
self *= 0 |
|
|
|
|
|
|
|
def to(self, cn): |
|
|
|
return apply(Copy(comp_node=cn), self)[0] |
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
|
r""" __getstate__ will be called for pickle serialization or deep copy |
|
|
|
""" |
|
|
|