|
|
@@ -16,7 +16,7 @@ import numpy as np |
|
|
|
|
|
|
|
import megengine as mge |
|
|
|
|
|
|
|
from .._imperative_rt import core2 |
|
|
|
from .._imperative_rt import core2, ops |
|
|
|
from ..ops.builtin import Elemwise, OpDef, RemoteSend |
|
|
|
from ..ops.special import Const |
|
|
|
from ..tensor.core import TensorBase, TensorWrapperBase, apply |
|
|
@@ -211,3 +211,25 @@ class Grad: |
|
|
|
|
|
|
|
def __exit__(self, _1, _2, _3): |
|
|
|
del self._impl |
|
|
|
|
|
|
|
|
|
|
|
class Function(ops.PyOpBase): |
|
|
|
def _default_rule(self, *args): |
|
|
|
ret = self.forward(*args) |
|
|
|
self.__single_output = isinstance(ret, core2.Tensor) |
|
|
|
return ret |
|
|
|
|
|
|
|
def _grad_rule(self, *args): |
|
|
|
return self._default_rule(*args), self.backward |
|
|
|
|
|
|
|
def __call__(self, *args): |
|
|
|
ret = core2.apply(self, *args) |
|
|
|
if self.__single_output: |
|
|
|
(ret,) = ret |
|
|
|
return ret |
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
|
return self.__dict__ |
|
|
|
|
|
|
|
def __setstate__(self, state): |
|
|
|
self.__dict__.update(state) |