|
|
@@ -11,7 +11,8 @@ from typing import Iterable, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ..core.ops.builtin import Copy |
|
|
|
from ..core._wrap import device as as_device |
|
|
|
from ..core.ops.builtin import Copy, Identity |
|
|
|
from ..core.tensor import Tensor |
|
|
|
from ..core.tensor.core import apply |
|
|
|
from .math import topk as _topk |
|
|
@@ -63,12 +64,12 @@ def accuracy( |
|
|
|
return accs |
|
|
|
|
|
|
|
|
|
|
|
def copy(inp, cn): |
|
|
|
def copy(inp, device=None): |
|
|
|
r""" |
|
|
|
Copies tensor to another device. |
|
|
|
|
|
|
|
:param inp: input tensor. |
|
|
|
:param cn: destination device. |
|
|
|
:param device: destination device. |
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
@@ -88,4 +89,6 @@ def copy(inp, cn): |
|
|
|
|
|
|
|
[1 2 3] |
|
|
|
""" |
|
|
|
return apply(Copy(comp_node=cn), inp)[0] |
|
|
|
if device is None: |
|
|
|
return apply(Identity(), inp)[0] |
|
|
|
return apply(Copy(comp_node=as_device(device).to_c()), inp)[0] |