@@ -16,7 +16,7 @@ import numpy as np
 
 import megengine as mge
 
-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []
 
+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
 
@@ -144,6 +147,7 @@ class Grad:
             o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []
 
     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
         return
 
     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward
 
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
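For context, the apparent intent of `remote_send_cache`: a `RemoteSend` op produces no output that is consumed locally, so nothing downstream on this worker holds a reference to its op node; the cache keeps such nodes (and the `backward` closures attached to them) alive until the backward pass runs, and `_exit` then drops them. The sketch below illustrates only this caching pattern in standalone Python; `OpNode`, `DemoGrad`, and `new_opnode` are hypothetical stand-ins, not MegEngine's actual classes or internals.

```python
import weakref


class OpNode:
    """Stand-in for a gradient-graph node; carries a backward closure."""

    def __init__(self, name):
        self.name = name
        self.backward = None


class DemoGrad:
    """Stand-in for Grad: ordinary nodes are only weakly referenced here
    (in the real tracer they stay alive through their outputs), while
    send-like ops, whose outputs are never used locally, are cached."""

    def __init__(self):
        self._nodes = weakref.WeakSet()
        self.remote_send_cache = []

    def new_opnode(self, name, is_send=False):
        node = OpNode(name)
        self._nodes.add(node)
        if is_send:
            # Keep a strong reference: nothing on this worker consumes a
            # send-like op's output, so without the cache the node (and its
            # backward closure) could be collected before backward runs.
            self.remote_send_cache.append(node)
        return node

    def clear(self):
        # Analogous to the new line in Grad._exit(): release the cache
        # once the gradient scope is torn down.
        self.remote_send_cache = []


mgr = DemoGrad()
node = mgr.new_opnode("remote_send", is_send=True)
node.backward = lambda grad: print("RemoteSend backward got", grad)
del node  # the local reference is gone, but the cache keeps the node alive

(cached,) = mgr.remote_send_cache
cached.backward("dy")  # backward can still fire on the sending worker
mgr.clear()            # dropped afterwards, as in _exit()
```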