|
|
@@ -19,7 +19,7 @@ from ..core.autodiff.grad import ( |
|
|
|
) |
|
|
|
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend |
|
|
|
from ..core.tensor.core import apply |
|
|
|
from ..core.tensor.tensor import Tensor, tensor_apply |
|
|
|
from ..core.tensor.tensor import Tensor |
|
|
|
from ..distributed.group import ( |
|
|
|
WORLD, |
|
|
|
Group, |
|
|
@@ -48,7 +48,7 @@ __all__ = [ |
|
|
|
|
|
|
|
@apply.register() |
|
|
|
def _(op: RemoteSend, *args: Tensor): |
|
|
|
ret = tensor_apply(op, *args) |
|
|
|
ret = apply.super(op, *args) |
|
|
|
|
|
|
|
# set extra information |
|
|
|
tracer_set = dict() |
|
|
|