GitOrigin-RevId: 32dd49a23a
tags/v1.8.0
@@ -28,9 +28,6 @@ class AttachSpec: | |||||
__slots__ = "tensor", "callbacks" | __slots__ = "tensor", "callbacks" | ||||
_global_priority = 0 | |||||
class GradManager: | class GradManager: | ||||
r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode | r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode | ||||
automatic differentiation (a.k.a. back propagation). | automatic differentiation (a.k.a. back propagation). | ||||
@@ -127,7 +124,6 @@ class GradManager: | |||||
self._grad = None | self._grad = None | ||||
self._after_backward_callback = [] | self._after_backward_callback = [] | ||||
self._gradients = {} | self._gradients = {} | ||||
self._priority = None | |||||
def attached_tensors(self): | def attached_tensors(self): | ||||
r"""Return attached tensor list from :meth:`attach`.""" | r"""Return attached tensor list from :meth:`attach`.""" | ||||
@@ -299,31 +295,25 @@ class GradManager: | |||||
tensor.grad = grad | tensor.grad = grad | ||||
else: | else: | ||||
tensor.grad += grad | tensor.grad += grad | ||||
if tensor._isscalar() and tensor.grad is not None: | |||||
tensor.grad._setscalar() | |||||
finally: | finally: | ||||
self.release() | self.release() | ||||
backwarding_grad_manager = cache | backwarding_grad_manager = cache | ||||
set_option("record_computing_path", 1) | |||||
pop_scope("backward") | |||||
set_option("record_computing_path", 1) | |||||
pop_scope("backward") | |||||
def record(self): | def record(self): | ||||
r"""Start recording operations | r"""Start recording operations | ||||
After this call, you will be able to call :meth:`backward`. | After this call, you will be able to call :meth:`backward`. | ||||
""" | """ | ||||
global _global_priority | |||||
if self._recording: | if self._recording: | ||||
raise RuntimeError("already recording") | raise RuntimeError("already recording") | ||||
grad = Grad() | grad = Grad() | ||||
self._recording = True | self._recording = True | ||||
self._grad = grad | self._grad = grad | ||||
grad.__enter__() | |||||
for spec in self._attach_specs.values(): | for spec in self._attach_specs.values(): | ||||
self._do_record(spec) | self._do_record(spec) | ||||
if self._priority is None: | |||||
grad._priority = _global_priority | |||||
_global_priority -= 1 | |||||
grad.__enter__() | |||||
def _do_record(self, spec): | def _do_record(self, spec): | ||||
tensor = spec.tensor() | tensor = spec.tensor() | ||||
@@ -331,6 +321,8 @@ class GradManager: | |||||
return | return | ||||
def callback(grad, callbacks=spec.callbacks): | def callback(grad, callbacks=spec.callbacks): | ||||
from ..functional import ones_like | |||||
for cb in callbacks: | for cb in callbacks: | ||||
grad = cb(tensor, grad) | grad = cb(tensor, grad) | ||||
self._gradients[id(tensor)] = grad | self._gradients[id(tensor)] = grad | ||||
@@ -343,14 +335,11 @@ class GradManager: | |||||
After this call, you will not be able to call :meth:`backward`. | After this call, you will not be able to call :meth:`backward`. | ||||
""" | """ | ||||
global _global_priority | |||||
if self._grad is not None: | if self._grad is not None: | ||||
self._grad.__exit__(None, None, None) | self._grad.__exit__(None, None, None) | ||||
self._grad = None | self._grad = None | ||||
self._recording = False | self._recording = False | ||||
self._gradients = dict() | self._gradients = dict() | ||||
if self._priority is None: | |||||
_global_priority += 1 | |||||
def __enter__(self): | def __enter__(self): | ||||
self.record() | self.record() | ||||
@@ -382,15 +371,14 @@ class GradManagerGroup: | |||||
__ror__ = merge_with | __ror__ = merge_with | ||||
def __enter__(self): | def __enter__(self): | ||||
global _global_priority | |||||
_global_priority += 1 | |||||
Grad.stack.append([]) | |||||
Grad.begin_group() | |||||
for gm in self._gms: | for gm in self._gms: | ||||
gm._priority = _global_priority | |||||
gm.record() | gm.record() | ||||
assert gm._grad is not None | |||||
Grad.end_group() | |||||
def __exit__(self, exc_type, exc_val, exc_tb): | def __exit__(self, exc_type, exc_val, exc_tb): | ||||
global _global_priority | |||||
_global_priority -= 1 | |||||
for gm in self._gms: | |||||
for gm in reversed(self._gms): | |||||
gm.release() | gm.release() | ||||
gm._priority = None | |||||
assert gm._grad is None |
@@ -6,17 +6,9 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import functools | |||||
import heapq | |||||
import itertools | |||||
import typing | |||||
import weakref | import weakref | ||||
import numpy as np | |||||
from .._imperative_rt import core2, ops | |||||
from ..ops.builtin import Elemwise, OpDef, RemoteSend | |||||
from ..ops.special import Const | |||||
from .._imperative_rt import core2 | |||||
_grad_count = 0 | _grad_count = 0 | ||||
_grad_manager_dict = weakref.WeakValueDictionary() | _grad_manager_dict = weakref.WeakValueDictionary() | ||||
@@ -36,6 +28,10 @@ class GradKey(core2.GradKey): | |||||
class Grad: | class Grad: | ||||
stack = [] | |||||
grouping = False | |||||
key2grad = weakref.WeakValueDictionary() | |||||
def __init__(self, name=None): | def __init__(self, name=None): | ||||
global _grad_count | global _grad_count | ||||
if name is None: | if name is None: | ||||
@@ -43,15 +39,9 @@ class Grad: | |||||
_grad_count += 1 | _grad_count += 1 | ||||
self._refkeeper = [] | self._refkeeper = [] | ||||
self._impl = GradKey(name) | self._impl = GradKey(name) | ||||
Grad.key2grad[self._impl] = self | |||||
_grad_manager_dict[self._name] = self | _grad_manager_dict[self._name] = self | ||||
@property | |||||
def _priority(self): | |||||
return self._impl.priority | |||||
@_priority.setter | |||||
def _priority(self, priority): | |||||
self._impl.priority = priority | |||||
self._group = [weakref.ref(self)] | |||||
@property | @property | ||||
def _name(self): | def _name(self): | ||||
@@ -70,33 +60,80 @@ class Grad: | |||||
if not isinstance(ys, Sequence): | if not isinstance(ys, Sequence): | ||||
ys = [ys] | ys = [ys] | ||||
if not isinstance(dys, Sequence): | if not isinstance(dys, Sequence): | ||||
dys = [dys] | dys = [dys] | ||||
group = [ref() for ref in self._group] | |||||
for grad in group: | |||||
if grad is self: | |||||
continue | |||||
grad.suppress() | |||||
self._impl.backward(ys, dys) | self._impl.backward(ys, dys) | ||||
for grad in group: | |||||
if grad is self: | |||||
continue | |||||
grad.resume() | |||||
self._refkeeper = None | self._refkeeper = None | ||||
return None | |||||
def __enter__(self): | def __enter__(self): | ||||
ref = weakref.ref(self) | |||||
self._impl.enter() | |||||
if Grad.grouping: | |||||
group = Grad.stack[-1] | |||||
self._group = group | |||||
group.append(ref) | |||||
else: | |||||
Grad.stack.append(self._group) | |||||
return self | return self | ||||
def __exit__(self, _1, _2, _3): | def __exit__(self, _1, _2, _3): | ||||
self._impl.exit() | |||||
self._refkeeper = None | self._refkeeper = None | ||||
del self._impl | |||||
class Function(ops.PyOpBase): | |||||
del Grad.key2grad[self._impl] | |||||
self._impl = None | |||||
self._group.remove(weakref.ref(self)) | |||||
if len(self._group) == 0: | |||||
Grad.stack.remove(self._group) | |||||
@staticmethod | |||||
def begin_group(): | |||||
assert not Grad.grouping | |||||
Grad.grouping = True | |||||
@staticmethod | |||||
def end_group(): | |||||
group = Grad.stack[-1] | |||||
assert len(group) > 0 | |||||
assert Grad.grouping | |||||
Grad.grouping = False | |||||
def suppress(self): | |||||
if self._impl is not None: | |||||
self._impl.suppress() | |||||
def resume(self): | |||||
if self._impl is not None: | |||||
self._impl.resume() | |||||
class Function: | |||||
r"""Defines a block of operations with customizable differentiation. | r"""Defines a block of operations with customizable differentiation. | ||||
The computation should be defined in ``forward`` method, with gradient | The computation should be defined in ``forward`` method, with gradient | ||||
computation defined in ``backward`` method. | computation defined in ``backward`` method. | ||||
Each instance of ``Function`` should be used only once during forwardding. | Each instance of ``Function`` should be used only once during forwardding. | ||||
Examples: | Examples: | ||||
.. code-block:: | .. code-block:: | ||||
class Sigmoid(Function): | class Sigmoid(Function): | ||||
def forward(self, x): | def forward(self, x): | ||||
y = 1 / (1 + F.exp(-x)) | y = 1 / (1 + F.exp(-x)) | ||||
@@ -115,7 +152,7 @@ class Function(ops.PyOpBase): | |||||
Returns: | Returns: | ||||
a tuple of Tensor or a single Tensor. | a tuple of Tensor or a single Tensor. | ||||
Note: | Note: | ||||
* This method should return a tuple of Tensor or a single Tensor representing the output | * This method should return a tuple of Tensor or a single Tensor representing the output | ||||
of the function. | of the function. | ||||
@@ -128,7 +165,7 @@ class Function(ops.PyOpBase): | |||||
Args: | Args: | ||||
output_grads: gradients of outputs that are returned by :meth:`forward`. | output_grads: gradients of outputs that are returned by :meth:`forward`. | ||||
Note: | Note: | ||||
* In case when some tensors of outputs are not related to loss function, the corresponding | * In case when some tensors of outputs are not related to loss function, the corresponding | ||||
values in ``output_grads`` would be ``None``. | values in ``output_grads`` would be ``None``. | ||||
@@ -148,10 +185,40 @@ class Function(ops.PyOpBase): | |||||
return self._default_rule(*args), self.backward | return self._default_rule(*args), self.backward | ||||
def __call__(self, *args): | def __call__(self, *args): | ||||
ret = core2.apply(self, *args) | |||||
for arg in args: | |||||
if not isinstance(arg, core2.Tensor): | |||||
raise TypeError( | |||||
"op Function expect type Tensor as inputs, got {}".format(type(arg)) | |||||
) | |||||
grad_key = core2.get_grad_key(args) | |||||
if grad_key is None: | |||||
return self._default_rule(*args) | |||||
grad = Grad.key2grad[grad_key] | |||||
group = [ref() for ref in grad._group] | |||||
for grad in group: | |||||
grad.suppress() | |||||
outputs, backward = self._grad_rule(*args) | |||||
for grad in reversed(group): | |||||
grad.resume() | |||||
def normalized_backward(*output_grads): | |||||
input_grads = backward(*output_grads) | |||||
if isinstance(input_grads, core2.Tensor) or input_grads is None: | |||||
input_grads = (input_grads,) | |||||
return input_grads | |||||
if self.__single_output: | if self.__single_output: | ||||
(ret,) = ret | |||||
return ret | |||||
outputs = (outputs,) | |||||
for grad in reversed(group): | |||||
if grad._impl is None: | |||||
continue | |||||
outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs) | |||||
if self.__single_output: | |||||
(outputs,) = outputs | |||||
return outputs | |||||
def __getstate__(self): | def __getstate__(self): | ||||
return self.__dict__ | return self.__dict__ | ||||
@@ -26,7 +26,6 @@ from .utils import ( | |||||
convert_inputs, | convert_inputs, | ||||
isscalar, | isscalar, | ||||
make_shape_tuple, | make_shape_tuple, | ||||
setscalar, | |||||
) | ) | ||||
_ElwMod = builtin.Elemwise.Mode | _ElwMod = builtin.Elemwise.Mode | ||||
@@ -34,14 +33,7 @@ _ElwMod = builtin.Elemwise.Mode | |||||
def _elwise_apply(args, mode): | def _elwise_apply(args, mode): | ||||
op = builtin.Elemwise(mode) | op = builtin.Elemwise(mode) | ||||
_isscalar = True | |||||
for i in args: | |||||
if isscalar(i) == False: | |||||
_isscalar = False | |||||
break | |||||
(result,) = apply(op, *args) | (result,) = apply(op, *args) | ||||
if _isscalar: | |||||
setscalar(result) | |||||
return result | return result | ||||
@@ -203,8 +195,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||||
op = builtin.RemoveAxis(axis=axis) | op = builtin.RemoveAxis(axis=axis) | ||||
(result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
if len(axis) == inp.ndim: | |||||
setscalar(result) | |||||
return result | return result | ||||
@@ -221,6 +211,7 @@ def _reduce(mode): | |||||
op = builtin.Reduce(mode=mode, axis=0) | op = builtin.Reduce(mode=mode, axis=0) | ||||
(result,) = apply(op, data) | (result,) = apply(op, data) | ||||
result = _remove_axis(result, 0) | |||||
elif isinstance(axis, collections.abc.Iterable): | elif isinstance(axis, collections.abc.Iterable): | ||||
axis = _normalize_axis(self.ndim, axis, reverse=True) | axis = _normalize_axis(self.ndim, axis, reverse=True) | ||||
for ai in axis: | for ai in axis: | ||||
@@ -239,8 +230,6 @@ def _reduce(mode): | |||||
if self.dtype == np.bool_: | if self.dtype == np.bool_: | ||||
if mode in ["min", "max"]: | if mode in ["min", "max"]: | ||||
result = result.astype("bool") | result = result.astype("bool") | ||||
if axis is None or self.ndim == 1: | |||||
setscalar(result) | |||||
return result | return result | ||||
return f | return f | ||||
@@ -457,7 +446,6 @@ class ArrayMethodMixin(abc.ABC): | |||||
len(args) == 0 | len(args) == 0 | ||||
), "transpose for scalar does not accept additional args" | ), "transpose for scalar does not accept additional args" | ||||
ret = self.to(self.device) | ret = self.to(self.device) | ||||
setscalar(ret) | |||||
return ret | return ret | ||||
if not args: | if not args: | ||||
args = range(self.ndim)[::-1] | args = range(self.ndim)[::-1] | ||||
@@ -111,7 +111,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
if not isinstance(tuple_val, tuple): | if not isinstance(tuple_val, tuple): | ||||
tuple_val = (tuple_val,) | tuple_val = (tuple_val,) | ||||
ndim_indexed = 0 | ndim_indexed = 0 | ||||
ndim_indexed_scalar = 0 | |||||
for i in tuple_val: | for i in tuple_val: | ||||
if not i is Ellipsis: | if not i is Ellipsis: | ||||
ndim_indexed += ( | ndim_indexed += ( | ||||
@@ -119,14 +118,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") | if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") | ||||
else 1 | else 1 | ||||
) | ) | ||||
if isscalar(i): | |||||
ndim_indexed_scalar += 1 | |||||
ret_scalar = False | |||||
try: | |||||
ret_scalar = ndim_indexed_scalar == inp.ndim | |||||
except ValueError: | |||||
# inp.ndim is unknown | |||||
pass | |||||
else: | else: | ||||
if ndim_indexed > inp.ndim: | if ndim_indexed > inp.ndim: | ||||
raise IndexError( | raise IndexError( | ||||
@@ -221,7 +212,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
items.append(item) | items.append(item) | ||||
if new_axes: | if new_axes: | ||||
raise IndexError("newaxis is not allowed here") | raise IndexError("newaxis is not allowed here") | ||||
return inp, tensors, items, use_subtensor, ret_scalar | |||||
return inp, tensors, items, use_subtensor | |||||
def try_condtake(tensor, index): | def try_condtake(tensor, index): | ||||
@@ -247,14 +238,12 @@ def getitem(tensor, index): | |||||
try_result = try_condtake(tensor, index) | try_result = try_condtake(tensor, index) | ||||
if len(try_result) == 2: | if len(try_result) == 2: | ||||
return try_result[0] | return try_result[0] | ||||
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | |||||
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||||
if use_subtensor: | if use_subtensor: | ||||
op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
else: | else: | ||||
op = builtin.IndexingMultiAxisVec(items=items) | op = builtin.IndexingMultiAxisVec(items=items) | ||||
(result,) = apply(op, tensor, *tensors) | (result,) = apply(op, tensor, *tensors) | ||||
if ret_scalar: | |||||
result._setscalar() | |||||
return result | return result | ||||
@@ -266,7 +255,7 @@ def setitem(tensor, index, value): | |||||
tensor = tensor.reshape(-1) | tensor = tensor.reshape(-1) | ||||
if not isinstance(value, (Tensor, SymbolVar)): | if not isinstance(value, (Tensor, SymbolVar)): | ||||
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | ||||
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||||
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||||
if use_subtensor: | if use_subtensor: | ||||
op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
else: | else: | ||||
@@ -17,6 +17,7 @@ import numpy as np | |||||
from .. import _imperative_rt | from .. import _imperative_rt | ||||
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat | from .._imperative_rt import GraphOptimizeOptions, SerializationFormat | ||||
from .._imperative_rt.core2 import apply | |||||
from .._wrap import as_device | from .._wrap import as_device | ||||
from ..ops.builtin import OpDef | from ..ops.builtin import OpDef | ||||
@@ -126,9 +127,8 @@ class Graph(_imperative_rt.ComputingGraph): | |||||
class VarNode: | class VarNode: | ||||
def __init__(self, node: _imperative_rt.VarNode, isscalar=False): | |||||
def __init__(self, node: _imperative_rt.VarNode): | |||||
self._node = node | self._node = node | ||||
self._isscalar = isscalar | |||||
if hasattr(self.graph, "_var_cache"): | if hasattr(self.graph, "_var_cache"): | ||||
self.graph._var_cache[node] = self | self.graph._var_cache[node] = self | ||||
@@ -530,9 +530,6 @@ def _unwrap(x): | |||||
def apply_normal_varnode(op: OpDef, *args: VarNode): | def apply_normal_varnode(op: OpDef, *args: VarNode): | ||||
# for PyOp like RemoteSend/Recv | |||||
if getattr(op, "op", None): | |||||
op = op.op | |||||
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | ||||
return _wrap(outputs) | return _wrap(outputs) | ||||
@@ -51,10 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
def astype(x, dtype): | def astype(x, dtype): | ||||
dtype = np.dtype(dtype) | dtype = np.dtype(dtype) | ||||
if not is_dtype_equal(x.dtype, dtype): | if not is_dtype_equal(x.dtype, dtype): | ||||
isscalar = x._isscalar() | |||||
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) | (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | ||||
if isscalar: | |||||
x._setscalar() | |||||
return x | return x | ||||
@@ -129,13 +126,6 @@ def isscalar(x): | |||||
return np.isscalar(x) | return np.isscalar(x) | ||||
def setscalar(x): | |||||
if isinstance(x, (Tensor, SymbolVar)): | |||||
x._setscalar() | |||||
else: | |||||
raise NotImplementedError("Unsupport type {}".format(type(x))) | |||||
def astensor1d(x, *reference, dtype=None, device=None): | def astensor1d(x, *reference, dtype=None, device=None): | ||||
"""Convert something to 1D tensor. Support following types | """Convert something to 1D tensor. Support following types | ||||
@@ -237,6 +227,7 @@ for name, mode in [ | |||||
("**", "pow"), | ("**", "pow"), | ||||
("max", "max"), | ("max", "max"), | ||||
("additive", "add"), | ("additive", "add"), | ||||
("exp", "EXP"), | |||||
]: | ]: | ||||
_opr_map[(name, 2)] = builtin.Elemwise(mode=mode) | _opr_map[(name, 2)] = builtin.Elemwise(mode=mode) | ||||
@@ -13,7 +13,7 @@ import numpy as np | |||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core.autodiff.grad import Function, _grad_manager_dict | from ..core.autodiff.grad import Function, _grad_manager_dict | ||||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | ||||
from ..core.tensor.utils import isscalar, setscalar | |||||
from ..core.tensor.utils import isscalar | |||||
from ..device import get_default_device, what_is_xpu | from ..device import get_default_device, what_is_xpu | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from . import group | from . import group | ||||
@@ -72,15 +72,6 @@ def collective_comm(inp, mode, group, device): | |||||
) | ) | ||||
(result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
# assume all workers have homogeneous shape | # assume all workers have homogeneous shape | ||||
if mode in ( | |||||
CollectiveComm.Mode.REDUCE_SUM, | |||||
CollectiveComm.Mode.BROADCAST, | |||||
CollectiveComm.Mode.ALL_REDUCE_SUM, | |||||
CollectiveComm.Mode.ALL_REDUCE_MAX, | |||||
CollectiveComm.Mode.ALL_REDUCE_MIN, | |||||
): | |||||
if isscalar(inp): | |||||
setscalar(result) | |||||
return result | return result | ||||
@@ -190,8 +181,7 @@ def reduce_sum( | |||||
# Rank 0 # output: None | # Rank 0 # output: None | ||||
# Rank 1 # output: Tensor([1]) | # Rank 1 # output: Tensor([1]) | ||||
""" | """ | ||||
op = _ReduceSum(group, device) | |||||
(out,) = apply(op, inp) | |||||
out = _ReduceSum(group, device)(inp) | |||||
if group.rank == 0: | if group.rank == 0: | ||||
return out | return out | ||||
@@ -258,8 +248,7 @@ def broadcast( | |||||
_bcast_tracer_state(group, inp) | _bcast_tracer_state(group, inp) | ||||
op = _Broadcast(group, device) | |||||
(out,) = apply(op, inp) | |||||
out = _Broadcast(group, device)(inp) | |||||
return out | return out | ||||
@@ -604,8 +593,7 @@ def gather( | |||||
inp.shape | inp.shape | ||||
) | ) | ||||
op = _Gather(group, device) | |||||
(out,) = apply(op, inp) | |||||
out = _Gather(group, device)(inp) | |||||
if group.rank == 0: | if group.rank == 0: | ||||
if axis == 0: | if axis == 0: | ||||
@@ -708,8 +696,7 @@ def scatter( | |||||
+ [_ for _ in range(axis + 1, inp.ndim + 1)] | + [_ for _ in range(axis + 1, inp.ndim + 1)] | ||||
) | ) | ||||
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape) | inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape) | ||||
op = _Scatter(group, device) | |||||
(out,) = apply(op, inp) | |||||
out = _Scatter(group, device)(inp) | |||||
return out | return out | ||||
@@ -832,7 +819,7 @@ class _RemoteRecv(Function): | |||||
self.op = op | self.op = op | ||||
def forward(self, dummy): | def forward(self, dummy): | ||||
return apply(self.op, dummy) | |||||
return apply(self.op, dummy)[0] | |||||
def backward(self, grad): | def backward(self, grad): | ||||
get_client().bcast_val(grad is not None, self.op.key, 2) | get_client().bcast_val(grad is not None, self.op.key, 2) | ||||
@@ -871,7 +858,7 @@ def remote_send(inp: Tensor, dest_rank: int): | |||||
op.addr, op.port = get_mm_server_addr() | op.addr, op.port = get_mm_server_addr() | ||||
op.rank_to = dest_rank | op.rank_to = dest_rank | ||||
op.backend = _backend() | op.backend = _backend() | ||||
(out,) = apply(_RemoteSend(op), inp) | |||||
out = _RemoteSend(op)(inp) | |||||
_save_output_for_autodiff(inp, out) | _save_output_for_autodiff(inp, out) | ||||
@@ -912,11 +899,6 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||||
inp = Tensor(0, device=device) | inp = Tensor(0, device=device) | ||||
_bcast_tracer_state(group, inp) | _bcast_tracer_state(group, inp) | ||||
_isscalar = False | |||||
if len(shape) == 0: | |||||
shape = (1,) | |||||
_isscalar = True | |||||
op = RemoteRecv() | op = RemoteRecv() | ||||
op.key = group.key | op.key = group.key | ||||
op.cn = device | op.cn = device | ||||
@@ -926,7 +908,5 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||||
op.rank_from = src_rank | op.rank_from = src_rank | ||||
op.backend = _backend() | op.backend = _backend() | ||||
(ret,) = apply(_RemoteRecv(op), inp) | |||||
if _isscalar: | |||||
setscalar(ret) | |||||
ret = _RemoteRecv(op)(inp) | |||||
return ret | return ret |
@@ -67,9 +67,6 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): | |||||
op.offsets = offsets | op.offsets = offsets | ||||
op.shapes = [s or (1,) for s in shapes] | op.shapes = [s or (1,) for s in shapes] | ||||
outputs = apply(op, inp) | outputs = apply(op, inp) | ||||
for s, x in zip(shapes, outputs): | |||||
if not s: | |||||
x._setscalar() | |||||
return outputs | return outputs | ||||
@@ -1,25 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from ..core._imperative_rt.core2 import ( | |||||
set_allow_higher_order_directive as _set_allow_higher_order_directive, | |||||
) | |||||
__all__ = [ | |||||
"enable_higher_order_directive", | |||||
"disable_higher_order_directive", | |||||
] | |||||
def enable_higher_order_directive(): | |||||
_set_allow_higher_order_directive(True) | |||||
def disable_higher_order_directive(): | |||||
_set_allow_higher_order_directive(False) |
@@ -12,8 +12,5 @@ from ..core.ops.builtin import InplaceAdd | |||||
def _inplace_add_(dest, delta, alpha, beta): | def _inplace_add_(dest, delta, alpha, beta): | ||||
isscalar = dest._isscalar() | |||||
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) | dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) | ||||
if isscalar: | |||||
dest._setscalar() | |||||
return dest | return dest |
@@ -19,7 +19,7 @@ from ..core.ops import builtin | |||||
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt | from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt | ||||
from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
from ..core.tensor import amp | from ..core.tensor import amp | ||||
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph | |||||
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph | |||||
from ..jit import exclude_from_trace | from ..jit import exclude_from_trace | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from ..utils.deprecation import deprecated_kwargs_default | from ..utils.deprecation import deprecated_kwargs_default | ||||
@@ -1149,7 +1149,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: | |||||
inp1.ndim <= 1 and inp2.ndim <= 1 | inp1.ndim <= 1 and inp2.ndim <= 1 | ||||
), "Input tensors for dot must be 1-dimensional or scalar" | ), "Input tensors for dot must be 1-dimensional or scalar" | ||||
(result,) = apply(op, inp1, inp2) | (result,) = apply(op, inp1, inp2) | ||||
setscalar(result) | |||||
return result | return result | ||||
@@ -1200,5 +1199,4 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: | |||||
for i in range(len(inps)): | for i in range(len(inps)): | ||||
inps[i]._reset(oups[i]) | inps[i]._reset(oups[i]) | ||||
out._setscalar() | |||||
return out | return out |
@@ -35,7 +35,6 @@ from ..core.tensor.utils import ( | |||||
cast_tensors, | cast_tensors, | ||||
convert_single_value, | convert_single_value, | ||||
make_shape_tuple, | make_shape_tuple, | ||||
setscalar, | |||||
subgraph, | subgraph, | ||||
) | ) | ||||
from ..device import get_default_device | from ..device import get_default_device | ||||
@@ -972,13 +972,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
) | ) | ||||
axis = sorted(axis) | axis = sorted(axis) | ||||
assert axis, "axis could not be empty" | assert axis, "axis could not be empty" | ||||
if inp._isscalar(): | |||||
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) | |||||
if len(axis) == 1: | |||||
inp = copy(inp, device=None) | |||||
inp._unsetscalar() | |||||
return inp | |||||
axis = axis[1:] | |||||
op = builtin.AddAxis(axis=axis) | op = builtin.AddAxis(axis=axis) | ||||
(result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
return result | return result | ||||
@@ -1164,8 +1157,6 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): | |||||
if axis is None: | if axis is None: | ||||
inp = inp.reshape(-1) # flatten | inp = inp.reshape(-1) # flatten | ||||
axis = 0 | axis = 0 | ||||
if inp._isscalar(): | |||||
inp._unsetscalar() | |||||
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device) | shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device) | ||||
# assume inp.ndim is not changed during trace | # assume inp.ndim is not changed during trace | ||||
max_axis = len(shape) - 1 | max_axis = len(shape) - 1 | ||||
@@ -6,19 +6,7 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from ..core._imperative_rt.core2 import ( | |||||
set_cpp_apply_const_with_tracing, | |||||
set_cpp_apply_with_tracing, | |||||
) | |||||
from .dtr_config import DTRConfig | from .dtr_config import DTRConfig | ||||
from .graph_opt_config import GraphOptimizationConfig | from .graph_opt_config import GraphOptimizationConfig | ||||
from .sublinear_memory_config import SublinearMemoryConfig | from .sublinear_memory_config import SublinearMemoryConfig | ||||
from .tracing import ( | |||||
apply_const_with_tracing, | |||||
apply_with_tracing, | |||||
exclude_from_trace, | |||||
trace, | |||||
) | |||||
set_cpp_apply_with_tracing(apply_with_tracing) | |||||
set_cpp_apply_const_with_tracing(apply_const_with_tracing) | |||||
from .tracing import TraceError, exclude_from_trace, trace |
@@ -111,6 +111,7 @@ class Module(metaclass=ABCMeta): | |||||
# used for profiler and automatic naming | # used for profiler and automatic naming | ||||
self._name = None | self._name = None | ||||
self._short_name = None | |||||
@abstractmethod | @abstractmethod | ||||
def forward(self, inputs): | def forward(self, inputs): | ||||
@@ -137,7 +138,7 @@ class Module(metaclass=ABCMeta): | |||||
return HookHandler(self._forward_hooks, hook) | return HookHandler(self._forward_hooks, hook) | ||||
def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
AutoNaming.push_scope(self.name if self.name is not None else self._name) | |||||
AutoNaming.push_scope(self.name if self.name is not None else self._short_name) | |||||
for hook in self._forward_pre_hooks.values(): | for hook in self._forward_pre_hooks.values(): | ||||
modified_inputs = hook(self, inputs) | modified_inputs = hook(self, inputs) | ||||
if modified_inputs is not None: | if modified_inputs is not None: | ||||
@@ -641,15 +642,43 @@ class Module(metaclass=ABCMeta): | |||||
else: | else: | ||||
if modules is not None and name in modules: | if modules is not None and name in modules: | ||||
modules.remove(name) | modules.remove(name) | ||||
for k, v in _expand_structure(name, value): | |||||
if not v._name: | |||||
v._name = k | |||||
elif v._name != k: | |||||
def append_name(prefix, name): | |||||
if prefix is None or prefix == "": | |||||
return name | |||||
return prefix + "." + name | |||||
def set_name(parent, prefix, name, obj): | |||||
if isinstance(obj, Tensor): | |||||
assert obj.name is not None | |||||
if obj.name != "": | |||||
name = obj.name | |||||
full_name = append_name(prefix, name) | |||||
if obj._short_name and obj._short_name != name: | |||||
logger.warning( | logger.warning( | ||||
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( | "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( | ||||
type(v), type(self), k, v._name | |||||
obj._short_name, type(parent), name, obj._short_name | |||||
) | ) | ||||
) | ) | ||||
return | |||||
if isinstance(obj, Tensor): | |||||
obj._prefix = prefix | |||||
obj._name = full_name | |||||
obj._short_name = name | |||||
obj._set_name(obj._name) | |||||
return obj._name | |||||
elif isinstance(obj, Module): | |||||
obj._name = full_name | |||||
obj._short_name = name | |||||
for k, v in obj._flatten(recursive=False, with_key=True): | |||||
set_name(obj, full_name, k, v) | |||||
return obj._name | |||||
else: | |||||
assert False | |||||
for k, v in _expand_structure(name, value): | |||||
prefix = self._name if self._name else self.name | |||||
set_name(self, prefix, k, v) | |||||
super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
def __delattr__(self, name: str): | def __delattr__(self, name: str): | ||||
@@ -14,6 +14,7 @@ from numpy.random import MT19937 | |||||
from .. import Tensor | from .. import Tensor | ||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core._imperative_rt.core2 import sync as _sync | |||||
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle | from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle | ||||
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | ||||
from ..core._imperative_rt.ops import ( | from ..core._imperative_rt.ops import ( | ||||
@@ -650,6 +651,10 @@ class RNG: | |||||
def __del__(self): | def __del__(self): | ||||
if self._handle != 0: | if self._handle != 0: | ||||
# RNG op might execute after handle released due to async dispatch, so | |||||
# we need sync before delete a handle to avoid memory leak or | |||||
# use-after-free | |||||
_sync() | |||||
_delete_rng_handle(self._handle) | _delete_rng_handle(self._handle) | ||||
@@ -12,7 +12,7 @@ import numpy as np | |||||
from .core._imperative_rt import CompNode | from .core._imperative_rt import CompNode | ||||
from .core._imperative_rt.core2 import Tensor as _Tensor | from .core._imperative_rt.core2 import Tensor as _Tensor | ||||
from .core._imperative_rt.core2 import apply | |||||
from .core._imperative_rt.core2 import apply, set_py_tensor_type | |||||
from .core._trace_option import use_symbolic_shape | from .core._trace_option import use_symbolic_shape | ||||
from .core._wrap import as_device | from .core._wrap import as_device | ||||
from .core.ops.builtin import Copy, GetVarShape | from .core.ops.builtin import Copy, GetVarShape | ||||
@@ -20,7 +20,6 @@ from .core.tensor.array_method import ArrayMethodMixin | |||||
from .device import _valid_device, get_default_device | from .device import _valid_device, get_default_device | ||||
from .logger import get_logger | from .logger import get_logger | ||||
from .utils.deprecation import deprecated | from .utils.deprecation import deprecated | ||||
from .utils.naming import AutoNaming | |||||
logger = get_logger(__name__) | logger = get_logger(__name__) | ||||
@@ -40,6 +39,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
grad = None | grad = None | ||||
dmap_callback = None | dmap_callback = None | ||||
_qparams = None | _qparams = None | ||||
_custom_name = "" | |||||
_name = None | |||||
_short_name = None | |||||
_prefix = None | |||||
def __new__( | def __new__( | ||||
cls, | cls, | ||||
@@ -81,9 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
device: str = None, | device: str = None, | ||||
is_const: bool = False, | is_const: bool = False, | ||||
no_cache: bool = False, | no_cache: bool = False, | ||||
name: str = None, | |||||
name: str = "", | |||||
): | ): | ||||
pass | |||||
if name is None: | |||||
name = "" | |||||
self._custom_name = name | |||||
self._name = name | |||||
self._short_name = name | |||||
self._set_name(self._name) | |||||
self._prefix = None | |||||
@property | @property | ||||
def shape(self) -> Union[tuple, "Tensor"]: | def shape(self) -> Union[tuple, "Tensor"]: | ||||
@@ -151,12 +160,13 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
@property | @property | ||||
def name(self): | def name(self): | ||||
return self.c_name | |||||
return self._custom_name | |||||
@name.setter | @name.setter | ||||
def name(self, name): | def name(self, name): | ||||
self.c_name = name | |||||
AutoNaming.record_var_name(self._mixin_handle, name) | |||||
self._custom_name = name | |||||
self._name = self._prefix + "." + name if self._prefix else name | |||||
self._set_name(self._name) | |||||
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | ||||
def set_value(self, value): | def set_value(self, value): | ||||
@@ -224,6 +234,9 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
self._qparams = qparams | self._qparams = qparams | ||||
set_py_tensor_type(Tensor) | |||||
tensor = Tensor | tensor = Tensor | ||||
@@ -6,7 +6,6 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||||
from . import compat | from . import compat | ||||
from ._passes import optimize | from ._passes import optimize | ||||
from .pytree import register_supported_type | from .pytree import register_supported_type | ||||
@@ -14,14 +13,12 @@ from .tm_config import disable_default_checker, enable_expr_checker | |||||
from .traced_module import ( | from .traced_module import ( | ||||
TracedModule, | TracedModule, | ||||
_register_all_builtin_module, | _register_all_builtin_module, | ||||
cpp_apply_module_trace, | |||||
register_as_builtin, | register_as_builtin, | ||||
trace_module, | trace_module, | ||||
wrap, | wrap, | ||||
) | ) | ||||
_register_all_builtin_module() | _register_all_builtin_module() | ||||
set_cpp_apply_module_trace(cpp_apply_module_trace) | |||||
__all__ = [ | __all__ = [ | ||||
"register_as_builtin", | "register_as_builtin", | ||||
@@ -13,7 +13,6 @@ import numpy as np | |||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core._imperative_rt.ops import ROIAlign, ROIPooling | from ..core._imperative_rt.ops import ROIAlign, ROIPooling | ||||
from ..core.ops.builtin import Copy | from ..core.ops.builtin import Copy | ||||
from ..core.tensor.utils import isscalar, setscalar | |||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .tm_config import _exclude_from_trace | from .tm_config import _exclude_from_trace | ||||
@@ -70,8 +69,6 @@ class TracedModuleChecker: | |||||
self.current_node2values()[node] = apply( | self.current_node2values()[node] = apply( | ||||
Copy(comp_node=value.device), value | Copy(comp_node=value.device), value | ||||
)[0] | )[0] | ||||
if isscalar(value): | |||||
setscalar(self.current_node2values()[node]) | |||||
def check_apply_special_cases(self, opdef, num_outputs): | def check_apply_special_cases(self, opdef, num_outputs): | ||||
indexs = list(range(num_outputs)) | indexs = list(range(num_outputs)) | ||||
@@ -20,6 +20,7 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor | |||||
from ..core._imperative_rt.core2 import ( | from ..core._imperative_rt.core2 import ( | ||||
apply, | apply, | ||||
is_tracing_module, | is_tracing_module, | ||||
set_module_trace_hook, | |||||
set_module_tracing, | set_module_tracing, | ||||
unset_module_tracing, | unset_module_tracing, | ||||
) | ) | ||||
@@ -605,8 +606,7 @@ class Apply(Expr): | |||||
def apply_module_trace_hook(cls, opdef, *inputs): | def apply_module_trace_hook(cls, opdef, *inputs): | ||||
for i in inputs: | for i in inputs: | ||||
node = NodeMixin.get(i, None) | node = NodeMixin.get(i, None) | ||||
if node is None: # capture as constant | |||||
NodeMixin.wrap_safe(i, Constant.make(i)) | |||||
assert node is not None | |||||
if isinstance(opdef, FakeQuant): | if isinstance(opdef, FakeQuant): | ||||
inp_nodes = [NodeMixin.get(inputs[0])] | inp_nodes = [NodeMixin.get(inputs[0])] | ||||
@@ -805,3 +805,12 @@ class Constant(Expr): | |||||
if isinstance(v, _ModuleState): | if isinstance(v, _ModuleState): | ||||
state[k] = v.to_module() | state[k] = v.to_module() | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
def _module_trace_capture(value): | |||||
node = Constant.make(value) | |||||
NodeMixin.wrap_safe(value, node) | |||||
return node | |||||
set_module_trace_hook(Apply.apply_module_trace_hook) |
@@ -101,9 +101,7 @@ BUILTIN_TENSOR_WRAP_METHOD = [ | |||||
"requires_grad", | "requires_grad", | ||||
"_reset", | "_reset", | ||||
"_isscalar", | "_isscalar", | ||||
"_setscalar", | |||||
"_tuple_shape", | "_tuple_shape", | ||||
"_unsetscalar", | |||||
] | ] | ||||
@@ -43,7 +43,6 @@ from ..core._imperative_rt.core2 import ( | |||||
) | ) | ||||
from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
from ..core.ops.builtin import Copy | from ..core.ops.builtin import Copy | ||||
from ..core.tensor.utils import isscalar, setscalar | |||||
from ..module import Module | from ..module import Module | ||||
from ..module import external as MExternal | from ..module import external as MExternal | ||||
from ..module.qat import QATModule | from ..module.qat import QATModule | ||||
@@ -1295,12 +1294,9 @@ def _wrapped_function(orig_func): | |||||
return orig_func(*args, **kwargs) | return orig_func(*args, **kwargs) | ||||
if isinstance(args[1], RawTensor): | if isinstance(args[1], RawTensor): | ||||
node = NodeMixin.get(inputs[1]) | node = NodeMixin.get(inputs[1]) | ||||
is_scalar = isscalar(inputs[1]) | |||||
inputs[1] = apply( | inputs[1] = apply( | ||||
Copy(comp_node=inputs[1].device), Tensor(inputs[1]) | Copy(comp_node=inputs[1].device), Tensor(inputs[1]) | ||||
)[0] | )[0] | ||||
if is_scalar: | |||||
setscalar(inputs[1]) | |||||
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, | # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, | ||||
# which will cause they have same _NodeMixin__node in tracing. | # which will cause they have same _NodeMixin__node in tracing. | ||||
NodeMixin.wrap_safe(inputs[1], node) | NodeMixin.wrap_safe(inputs[1], node) | ||||
@@ -2468,8 +2464,8 @@ def trace_module( | |||||
try: | try: | ||||
net_name = mod._name if mod._name else mod.__class__.__name__ | net_name = mod._name if mod._name else mod.__class__.__name__ | ||||
use_sym_shape = set_symbolic_shape(True) | use_sym_shape = set_symbolic_shape(True) | ||||
set_module_tracing() | |||||
set_active_module_tracer(module_tracer(_wrapped_function)) | set_active_module_tracer(module_tracer(_wrapped_function)) | ||||
set_module_tracing() | |||||
for cls in [Expr, Node]: | for cls in [Expr, Node]: | ||||
cls._set_next_id(0) | cls._set_next_id(0) | ||||
with active_module_tracer().patcher: | with active_module_tracer().patcher: | ||||
@@ -2518,9 +2514,9 @@ def trace_module( | |||||
return traced_mod | return traced_mod | ||||
finally: | finally: | ||||
set_symbolic_shape(use_sym_shape) | set_symbolic_shape(use_sym_shape) | ||||
set_active_module_tracer(None) | |||||
unset_module_tracing() | unset_module_tracing() | ||||
for t in mod.tensors(recursive=True): | for t in mod.tensors(recursive=True): | ||||
NodeMixin.clear_node(t) | NodeMixin.clear_node(t) | ||||
for t in inputs: | for t in inputs: | ||||
NodeMixin.clear_node(t) | NodeMixin.clear_node(t) | ||||
set_active_module_tracer(None) |
@@ -137,6 +137,11 @@ class Profiler(ContextDecorator): | |||||
get_logger().info("process {} generating {}".format(self._pid, format)) | get_logger().info("process {} generating {}".format(self._pid, format)) | ||||
self._dump_callback(path, format) | self._dump_callback(path, format) | ||||
get_logger().info("profiling results written to {}".format(path)) | get_logger().info("profiling results written to {}".format(path)) | ||||
if os.path.getsize(path) > 64 * 1024 * 1024: | |||||
get_logger().warning( | |||||
"profiling results too large, maybe you are profiling multi iters," | |||||
"consider attach profiler in each iter separately" | |||||
) | |||||
self._dump_callback = None | self._dump_callback = None | ||||
_living_profilers.remove(self) | _living_profilers.remove(self) | ||||
@@ -9,9 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||||
#include "./grad.h" | #include "./grad.h" | ||||
#include "megbrain/imperative/backward_graph_opt.h" | #include "megbrain/imperative/backward_graph_opt.h" | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
@@ -19,465 +18,19 @@ | |||||
#include "range/v3/all.hpp" | #include "range/v3/all.hpp" | ||||
#include "./transformation.h" | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
namespace views = ranges::views; | namespace views = ranges::views; | ||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
using scoped_disable = ApplyContext::scoped_disable; | |||||
using Flags = Tensor::Flags; | |||||
namespace { | namespace { | ||||
struct GradSlotWeakPtr { | |||||
std::weak_ptr<GradFn> grad_fn; | |||||
size_t idx; | |||||
}; | |||||
std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||||
ApplyContext& ctx, const apply_result_t& outputs) { | |||||
// hash | |||||
using OptimizedBackwardGraphCache = OpMethResultCache< | |||||
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>; | |||||
thread_local OptimizedBackwardGraphCache cache; | |||||
decltype(cache)::key_t cache_key{ctx.op}; | |||||
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs; | |||||
SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras); | |||||
input_descs.resize(ctx.nargs); | |||||
input_requires_grad.resize(ctx.nargs); | |||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
input_descs[i].layout.dtype = ctx.args[i]->dtype(); | |||||
input_descs[i].comp_node = ctx.args[i]->comp_node(); | |||||
input_requires_grad[i] = python::input_requires_grad(ctx, i); | |||||
} | |||||
auto iter = cache.find(cache_key); | |||||
if (iter != cache.end()) { | |||||
return iter->second; | |||||
} | |||||
// slow path | |||||
SmallVector<bool> output_has_grad(outputs.size(), true); | |||||
std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||||
auto bg = OpDef::make_backward_graph( | |||||
*ctx.op, input_descs, input_requires_grad, output_has_grad); | |||||
if (!bg.graph.empty()) { | |||||
ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||||
} | |||||
cache.emplace(cache_key, ret); | |||||
return ret; | |||||
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map; | |||||
} | } | ||||
struct BackwardGraphWithClosure { | |||||
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph; | |||||
SmallVector<std::shared_ptr<Tensor>> closure; | |||||
size_t output_mask_offset; | |||||
size_t grad_mask_offset; | |||||
BackwardGraphWithClosure( | |||||
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_, | |||||
ApplyContext& ctx, const apply_result_t& outputs) | |||||
: backward_graph(backward_graph_), | |||||
output_mask_offset(ctx.nargs), | |||||
grad_mask_offset(ctx.nargs + outputs.size()) { | |||||
// save_for_backward[0:nargs]: | |||||
// whether input is kept for backward | |||||
// | |||||
// save_for_backward[nargs:nargs+outputs.size()]: | |||||
// whether output is kept for backward | |||||
// | |||||
// save_for_backward[-outputs.size():]: | |||||
// whether gradient of output can propagate to any input | |||||
// | |||||
// Example: | |||||
// perform c = a * b, with a.requires_grad == True and | |||||
// b.requires_grad == False, save_for_backward = [0, 1, 0, 1] | |||||
auto& save_for_backward = backward_graph->save_for_backward; | |||||
mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); | |||||
size_t count = std::count_if( | |||||
save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); | |||||
if (!backward_graph->precomp.empty()) { | |||||
auto&& irng = ranges::span(ctx.args, ctx.nargs); | |||||
auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); }); | |||||
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | |||||
closure.reserve(precomp.size() + count); | |||||
std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure)); | |||||
} else { | |||||
closure.reserve(count); | |||||
} | |||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
if (save_for_backward[i]) { | |||||
closure.push_back(ctx.args[i]->shared_from_this()); | |||||
} | |||||
} | |||||
for (size_t i = 0; i < outputs.size(); ++i) { | |||||
if (save_for_backward[ctx.nargs + i]) { | |||||
closure.push_back(outputs[i]); | |||||
} | |||||
} | |||||
} | |||||
template <typename T, typename R> | |||||
void operator()(BackwardContext&, T&& grads, R&& receiver) { | |||||
Tensor* args[closure.size() + grads.size()]; | |||||
size_t nargs = 0; | |||||
for (auto&& t : closure) { | |||||
args[nargs++] = t.get(); | |||||
} | |||||
bool null_grad = false; | |||||
for (size_t i = 0; i < grads.size(); ++i) { | |||||
if (backward_graph->save_for_backward[grad_mask_offset + i]) { | |||||
if (grads[i]) { | |||||
if (null_grad) { | |||||
PyErr_SetString(PyExc_NotImplementedError, "report to devs"); | |||||
throw py::error_already_set(); | |||||
} | |||||
args[nargs++] = grads[i]; | |||||
} else { | |||||
null_grad = true; | |||||
} | |||||
} | |||||
} | |||||
if (null_grad) | |||||
return; | |||||
auto igrads = apply(backward_graph->backward, args, nargs); | |||||
auto&& it = igrads.begin(); | |||||
for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) { | |||||
if (p) { | |||||
receiver(i, std::move(*it)); | |||||
++it; | |||||
} | |||||
} | |||||
} | |||||
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } | |||||
bool output_requires_grad(size_t i) { | |||||
return backward_graph->save_for_backward[grad_mask_offset + i]; | |||||
} | |||||
bool output_captured(size_t i) { | |||||
return backward_graph->save_for_backward[output_mask_offset + i]; | |||||
} | |||||
}; | |||||
struct PythonBackward { | |||||
py::object pyfunc; | |||||
size_t input_size; | |||||
PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {} | |||||
template <typename T, typename R> | |||||
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||||
auto args = py::tuple(grads.size()); | |||||
for (size_t i = 0; i < grads.size(); ++i) { | |||||
auto&& g = grads[i]; | |||||
args[i] = g ? ctx.wrap_tensor(g) : py::none(); | |||||
} | |||||
auto input_grads = py::reinterpret_steal<py::object>( | |||||
PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); | |||||
if (!input_grads) | |||||
throw py::error_already_set(); | |||||
if (input_grads.is_none()) | |||||
return; | |||||
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { | |||||
if (input_size != 1) { | |||||
throw py::value_error( | |||||
"custom grad rule returned wrong number of grads"); | |||||
} | |||||
if (!ctx.pytype) { | |||||
ctx.pytype = Py_TYPE(input_grads.ptr()); | |||||
} | |||||
receiver(0, tw->m_tensor); | |||||
return; | |||||
} | |||||
if (py::len(input_grads) != input_size) { | |||||
throw py::value_error("custom grad rule returned wrong number of grads"); | |||||
} | |||||
for (auto [i, g] : views::enumerate(input_grads)) { | |||||
if (g.is_none()) | |||||
continue; | |||||
auto* tw = TensorWrapper::try_cast(g.ptr()); | |||||
if (!tw) { | |||||
throw py::type_error("custom grad rule returned non-tensor"); | |||||
} | |||||
if (!ctx.pytype) { | |||||
ctx.pytype = Py_TYPE(g.ptr()); | |||||
} | |||||
receiver(i, tw->m_tensor); | |||||
} | |||||
} | |||||
static constexpr bool input_has_grad(size_t) { return true; } | |||||
static constexpr bool output_requires_grad(size_t) { return true; } | |||||
static constexpr bool output_captured(size_t) { return true; } | |||||
}; | |||||
} // namespace | |||||
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> { | |||||
using Base = intrusive_list::Node<GradProducerRecord>; | |||||
GradProducerRecord() = default; | |||||
GradProducerRecord(GradProducerRecord::head_t& head) | |||||
: Base(intrusive_list::after_t{}, head) {} | |||||
// GradProducerRecord(GradProducerRecord&&) = default; | |||||
// GradProducerRecord& operator=(GradProducerRecord&) = default; | |||||
// GradProducerRecord& operator=(GradProducerRecord&&) = default; | |||||
}; | |||||
struct GradSlot { | |||||
std::shared_ptr<Tensor> grad; | |||||
py::object callback; | |||||
GradProducerRecord::head_t producer_head; | |||||
}; | |||||
struct GradSlotProducerPtr : GradSlotPtr { | |||||
GradProducerRecord producer_record; | |||||
GradSlotProducerPtr() = default; | |||||
GradSlotProducerPtr(GradInfo& info) | |||||
: GradSlotPtr(info), producer_record(info->producer_head) {} | |||||
}; | |||||
struct GradFn : std::enable_shared_from_this<GradFn> { | |||||
static MemPool<GradFn> pool; | |||||
std::weak_ptr<GradKey> key; | |||||
// slots for receiving and accumulating grads | |||||
// same length as outputs (of forward op) | |||||
SmallVector<GradSlot> slots; | |||||
// where to send and accumulate grads | |||||
// same length as inputs (of forward op) | |||||
SmallVector<GradSlotProducerPtr> dsts; | |||||
// encapsules actual function to compute gradient | |||||
std::variant< | |||||
std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> | |||||
backward; | |||||
// a flag used during backward | |||||
bool in_ref_keeper = false; | |||||
static void deleter(GradFn* ptr) { pool.free(ptr); } | |||||
static std::shared_ptr<GradFn> make() { | |||||
return std::shared_ptr<GradFn>(pool.alloc(), &deleter); | |||||
} | |||||
void clear() { | |||||
key.reset(); | |||||
slots.clear(); | |||||
dsts.clear(); | |||||
backward.emplace<std::monostate>(); | |||||
} | |||||
}; | |||||
GradSlotPtr::operator bool() const { | |||||
return bool(grad_fn); | |||||
} | |||||
GradSlot* GradSlotPtr::operator->() { | |||||
return &grad_fn->slots[idx]; | |||||
} | |||||
namespace { | |||||
class GradFnHelper { | |||||
std::shared_ptr<GradFn> grad_fn; | |||||
GradFn* get() { | |||||
if (!grad_fn) { | |||||
grad_fn = std::make_shared<GradFn>(); | |||||
} | |||||
return grad_fn.get(); | |||||
} | |||||
friend apply_result_t imperative::python::apply_grad(ApplyContext&); | |||||
public: | |||||
template <typename T, typename... Args> | |||||
auto& emplace(Args&&... args) { | |||||
return get()->backward.emplace<T>(std::forward<Args>(args)...); | |||||
} | |||||
void reset() { grad_fn = nullptr; } | |||||
}; | |||||
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||||
// copy inputs first, or trace will make InputNodes for each usage | |||||
ApplyContext ctx_dup = ctx; | |||||
SmallVector<std::shared_ptr<Tensor>> inputs_copy; | |||||
SmallVector<Tensor*> inputs_copy_weak; | |||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
Tensor* input = ctx.args[i]; | |||||
inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]); | |||||
inputs_copy_weak.push_back(inputs_copy.back().get()); | |||||
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; | |||||
if (input->m_flags & Flags::GRAD) { | |||||
inputs_copy.back()->m_flags |= Flags::GRAD; | |||||
} | |||||
} | |||||
ctx_dup.args = inputs_copy_weak.data(); | |||||
auto outputs = apply(ctx_dup); | |||||
auto backward_graph = make_backward_graph(ctx_dup, outputs); | |||||
if (!backward_graph) { | |||||
return outputs; | |||||
} | |||||
ret_grad_fn.emplace<BackwardGraphWithClosure>( | |||||
std::move(backward_graph), ctx_dup, outputs); | |||||
return outputs; | |||||
} | |||||
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||||
auto* op = ctx.op->try_cast_final<GenericPyOp>(); | |||||
py::tuple pyin(ctx.nargs); | |||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); | |||||
} | |||||
auto grad_rule = py::getattr(op->obj, "_grad_rule"); | |||||
auto pyret = py::reinterpret_steal<py::object>( | |||||
PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); | |||||
if (!pyret) | |||||
throw py::error_already_set(); | |||||
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret); | |||||
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs); | |||||
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { | |||||
return {tw->m_tensor}; | |||||
} | |||||
apply_result_t ret; | |||||
ret.reserve(py::len(outputs)); | |||||
for (auto&& i : outputs) { | |||||
auto* tw = TensorWrapper::try_cast(i.ptr()); | |||||
mgb_assert(tw); | |||||
ret.push_back(tw->m_tensor); | |||||
} | |||||
return ret; | |||||
} | |||||
} // namespace | |||||
apply_result_t apply_grad(ApplyContext& ctx) { | |||||
std::unordered_set<std::shared_ptr<GradKey>> grad_keys; | |||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
auto* tensor = ctx.args[i]; | |||||
if (!tensor->m_grad_info_dict.empty()) { | |||||
size_t grad_cnt = 0; | |||||
for (auto&& grad_info : tensor->m_grad_info_dict) { | |||||
auto input_grad_key = grad_info.grad_fn->key.lock(); | |||||
if (input_grad_key && input_grad_key->active && | |||||
!input_grad_key->is_blocked()) { | |||||
grad_keys.insert(input_grad_key); | |||||
grad_cnt++; | |||||
} | |||||
} | |||||
if (!grad_cnt) { | |||||
tensor->m_flags &= ~Flags::GRAD; | |||||
} | |||||
} else { | |||||
tensor->m_flags &= ~Flags::GRAD; | |||||
} | |||||
} | |||||
ctx.flags &= ~Flags::GRAD; | |||||
if (grad_keys.empty()) { | |||||
return apply(ctx); | |||||
} else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) { | |||||
PyErr_SetString( | |||||
PyExc_NotImplementedError, | |||||
"second order directive not enabled, please call " | |||||
"'megengine.experimental.enable_higher_order_directive'"); | |||||
throw pyext17::py_err_set(); | |||||
} | |||||
GradFnHelper grad_fn_holder; | |||||
auto outputs = [&]() { | |||||
auto _ = scoped_disable(Flags::GRAD); | |||||
if (ctx.op->same_type<GenericPyOp>()) { | |||||
return python_grad_rule(ctx, grad_fn_holder); | |||||
} | |||||
auto&& registry = grad_rule_registry(); | |||||
auto&& it = registry.find(ctx.op->dyn_typeinfo()); | |||||
if (it != registry.end()) { | |||||
auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx); | |||||
if (auto ret = it->second(ctx, maker)) { | |||||
maker.finalize(); | |||||
return *ret; | |||||
} | |||||
grad_fn_holder.reset(); | |||||
} | |||||
return backward_graph_grad_rule(ctx, grad_fn_holder); | |||||
}(); | |||||
if (!grad_fn_holder.grad_fn) { | |||||
return outputs; | |||||
} | |||||
for (auto&& grad_key : grad_keys) { | |||||
auto grad_fn = std::make_shared<GradFn>(); | |||||
grad_fn->backward = grad_fn_holder.grad_fn->backward; | |||||
grad_fn->key = grad_key; | |||||
grad_fn->slots.resize(outputs.size()); | |||||
grad_fn->dsts.reserve(ctx.nargs); | |||||
std::visit( | |||||
[&](auto& backward) { | |||||
using T = std::decay_t<decltype(backward)>; | |||||
if constexpr (std::is_same_v<T, std::monostate>) { | |||||
mgb_assert(0); | |||||
} else { | |||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
if (backward.input_has_grad(i) && | |||||
input_requires_grad(ctx, i) && | |||||
ctx.args[i]->m_grad_info_dict.count(grad_key.get())) { | |||||
auto& input_grad_info = | |||||
ctx.args[i]->m_grad_info_dict.at( | |||||
grad_key.get()); | |||||
grad_fn->dsts.emplace_back(input_grad_info); | |||||
// register as grad producer | |||||
grad_fn->dsts.back().producer_record.insert_after( | |||||
input_grad_info->producer_head); | |||||
} else { | |||||
grad_fn->dsts.emplace_back(); | |||||
} | |||||
} | |||||
for (size_t i = 0; i < outputs.size(); ++i) { | |||||
if (backward.output_requires_grad(i)) { | |||||
if (backward.output_captured(i)) { | |||||
// avoid reference cycle [Tensor <-> GradFn] | |||||
static std::shared_ptr<OpDef> op = | |||||
std::make_shared<FastpathCopy>(); | |||||
outputs[i] = python::apply(op, outputs[i])[0]; | |||||
} | |||||
// populate grad info of output tensor | |||||
auto& grad_info = | |||||
outputs[i]->m_grad_info_dict[grad_key.get()]; | |||||
grad_info.grad_fn = grad_fn; | |||||
grad_info.idx = i; | |||||
grad_info.insert_after(grad_key->free_vars_head); | |||||
outputs[i]->m_flags |= Flags::GRAD; | |||||
} | |||||
} | |||||
} | |||||
}, | |||||
grad_fn->backward); | |||||
// record forward history | |||||
grad_key->tape.emplace_back(grad_fn); | |||||
} | |||||
return outputs; | |||||
} | |||||
PyObject* GradKeyWrapper::get_priority() { | |||||
return py::cast(m_key->priority).release().ptr(); | |||||
} | |||||
void GradKeyWrapper::set_priority(pybind11::handle priority) { | |||||
m_key->priority = py::cast<int>(priority); | |||||
GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) { | |||||
grad_key_map[m_key] = this; | |||||
} | } | ||||
void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | ||||
@@ -488,157 +41,59 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||||
if (!tw) { | if (!tw) { | ||||
throw py::type_error("argument 1 must be Tensor"); | throw py::type_error("argument 1 must be Tensor"); | ||||
} | } | ||||
auto* tensor = tw->m_tensor.get(); | |||||
py::object callback; | py::object callback; | ||||
if (args[1] != Py_None) { | if (args[1] != Py_None) { | ||||
callback = py::reinterpret_borrow<py::object>(args[1]); | callback = py::reinterpret_borrow<py::object>(args[1]); | ||||
} | } | ||||
m_key->attach(tensor, std::move(callback)); | |||||
} | |||||
//! GradKey is weakly refered by tensor->m_grad_info.grad_fn->key after attach | |||||
void GradKey::attach(Tensor* tensor, pybind11::object callback) { | |||||
if (!active) { | |||||
throw py::value_error("grad key finalized"); | |||||
} | |||||
if (tensor->m_grad_info_dict.count(this)) { | |||||
if (tensor->m_grad_info_dict.at(this)->callback) { | |||||
throw py::value_error("callback already set on this tensor"); | |||||
GenericFunction generic_callback = | |||||
[=](Span<ValueRef> inputs) -> std::vector<ValueRef> { | |||||
mgb_assert(inputs.size() == 1); | |||||
if (callback) { | |||||
callback(TensorWrapper::make(py_tensor_type, inputs[0])); | |||||
} | } | ||||
} else { | |||||
auto& grad_info = tensor->m_grad_info_dict[this]; | |||||
grad_info.idx = 0; | |||||
auto& grad_fn = grad_info.grad_fn; | |||||
grad_fn = std::make_shared<GradFn>(); | |||||
grad_fn->key = shared_from_this(); | |||||
grad_fn->slots.resize(1); | |||||
grad_info.insert_after(free_vars_head); | |||||
tensor->m_flags |= Flags::GRAD; | |||||
} | |||||
tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback); | |||||
} | |||||
template <typename T> | |||||
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) { | |||||
if (!grad) { | |||||
grad = std::forward<T>(delta); | |||||
return; | |||||
} | |||||
static std::shared_ptr<OpDef> op = | |||||
std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD)); | |||||
grad = apply(op, grad, std::forward<T>(delta))[0]; | |||||
return {}; | |||||
}; | |||||
tw->m_tensor->reset(imperative::apply( | |||||
AttachGrad(m_key), tw->m_tensor->data(), | |||||
FunctionValue::make(generic_callback))[0]); | |||||
} | } | ||||
void GradKey::backward( | |||||
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||||
if (!active) { | |||||
throw py::value_error("finalized"); | |||||
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { | |||||
std::vector<ValueRef> args; | |||||
mgb_assert(tensors.size() == grads.size()); | |||||
for (auto&& tensor : tensors) { | |||||
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||||
} | } | ||||
if (tensors.size() != grads.size()) { | |||||
throw py::value_error("tensor and grad size mismatch"); | |||||
for (auto&& grad : grads) { | |||||
args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); | |||||
} | } | ||||
// this GradKey is marked inactive here | |||||
active = false; | |||||
struct CleanupGuard { | |||||
GradKey* owner; | |||||
size_t priority_backup; | |||||
CleanupGuard(GradKey* this_) : owner(this_) { | |||||
priority_backup = sm_min_priority; | |||||
sm_min_priority = owner->priority + 1; | |||||
} | |||||
~CleanupGuard() { | |||||
owner->cleanup(); | |||||
sm_min_priority = priority_backup; | |||||
} | |||||
} _cleanup_guard(this); | |||||
if (tape.empty()) | |||||
return; | |||||
BackwardContext bctx; | |||||
if (!grads.empty()) { | |||||
bctx.pytype = Py_TYPE(grads[0]->self().ptr()); | |||||
} | |||||
for (size_t i = 0; i < tensors.size(); ++i) { | |||||
if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) { | |||||
continue; | |||||
} | |||||
auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this); | |||||
grad_info->grad = grads[i]->m_tensor; | |||||
} | |||||
std::vector<std::shared_ptr<GradFn>> ref_keeper; | |||||
ref_keeper.reserve(tape.size()); | |||||
// back-propagation in reverse order | |||||
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { | |||||
auto&& grad_fn = tape[k].lock(); | |||||
if (!grad_fn) | |||||
continue; | |||||
auto grad_receiver = [&](size_t i, auto&& g) { | |||||
auto& dst = grad_fn->dsts[i]; | |||||
if (dst) { | |||||
accum_grad(dst->grad, std::forward<decltype(g)>(g)); | |||||
} | |||||
}; | |||||
std::visit( | |||||
[&](auto&& backward) { | |||||
using T = std::decay_t<decltype(backward)>; | |||||
if constexpr (std::is_same_v<T, std::monostate>) { | |||||
mgb_assert(0); | |||||
} else { | |||||
auto&& grads = views::transform( | |||||
grad_fn->slots, | |||||
[](auto&& slot) { return slot.grad.get(); }); | |||||
backward( | |||||
bctx, std::forward<decltype(grads)>(grads), | |||||
grad_receiver); | |||||
} | |||||
}, | |||||
grad_fn->backward); | |||||
for (auto&& dst : grad_fn->dsts) { | |||||
if (!dst.grad_fn) | |||||
continue; | |||||
if (!dst.grad_fn->in_ref_keeper) { | |||||
// after grad_fn is cleared, refcnt of subsequent grad_fn | |||||
// could drop to 0 | |||||
dst.grad_fn->in_ref_keeper = true; | |||||
ref_keeper.push_back(dst.grad_fn); | |||||
} | |||||
if (!dst.producer_record.next && dst->callback && dst->grad) { | |||||
// I'm the last grad producer, invoke callback | |||||
dst->callback(bctx.wrap_tensor(dst->grad)); | |||||
} | |||||
} | |||||
grad_fn->clear(); | |||||
} // finish tape loop | |||||
imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); | |||||
} | } | ||||
void GradKey::cleanup() { | |||||
active = false; | |||||
tape.clear(); | |||||
for (intrusive_list::Iterator it(free_vars_head); it;) { | |||||
it->grad_fn.reset(); | |||||
(it++)->unlink(); | |||||
pybind11::function GradKeyWrapper::get_backward_closure( | |||||
GradKeyWrapper* self, py::list tensors) { | |||||
std::vector<ValueRef> args; | |||||
for (auto&& tensor : tensors) { | |||||
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||||
} | } | ||||
} | |||||
void GradKeyWrapper::backward( | |||||
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||||
m_key->backward(std::move(tensors), std::move(grads)); | |||||
auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0] | |||||
.as<FunctionValue>(); | |||||
auto py_function = [closure](std::vector<TensorWrapper*> tensors) { | |||||
std::vector<ValueRef> args; | |||||
for (auto* tw : tensors) { | |||||
args.push_back(tw->m_tensor->data()); | |||||
} | |||||
(*closure)(args); | |||||
}; | |||||
return pybind11::cpp_function(py_function); | |||||
} | } | ||||
PyObject* GradKeyWrapper::get_name() { | PyObject* GradKeyWrapper::get_name() { | ||||
return py::cast(m_key->name).release().ptr(); | |||||
return py::cast(m_key->name()).release().ptr(); | |||||
} | } | ||||
void GradKeyWrapper::set_name(py::handle name) { | void GradKeyWrapper::set_name(py::handle name) { | ||||
m_key->name = py::cast<std::string>(name); | |||||
m_key->name(py::cast<std::string>(name)); | |||||
} | } | ||||
PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | ||||
@@ -651,60 +106,39 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||||
PyErr_SetString(PyExc_TypeError, "expect Tensor"); | PyErr_SetString(PyExc_TypeError, "expect Tensor"); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) { | |||||
if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0] | |||||
.cast<BoolValue>()) { | |||||
Py_RETURN_TRUE; | Py_RETURN_TRUE; | ||||
} | } | ||||
Py_RETURN_FALSE; | Py_RETURN_FALSE; | ||||
} | } | ||||
int GradKey::sm_min_priority = std::numeric_limits<int>::min(); | |||||
GradKey::~GradKey() { | |||||
cleanup(); | |||||
void GradKeyWrapper::enter() { | |||||
m_transformation = std::make_shared<GradTransformation>(m_key); | |||||
TransformationManager::get_instance().register_at<TransformationManager::Grad>( | |||||
m_transformation); | |||||
} | } | ||||
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() { | |||||
static std::unordered_map<Typeinfo*, GradRuleFn> registry; | |||||
return registry; | |||||
void GradKeyWrapper::exit() { | |||||
TransformationManager::get_instance().unregister<TransformationManager::Grad>( | |||||
m_transformation); | |||||
m_transformation.reset(); | |||||
} | } | ||||
void GradInfoCollection::_shrink() { | |||||
auto pred = [](GradInfo& info) { | |||||
return !(info.grad_fn) || info.grad_fn->key.expired(); | |||||
}; | |||||
auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred); | |||||
m_storage.erase(iter, m_storage.end()); | |||||
void GradKeyWrapper::suppress() { | |||||
m_transformation->suppress(); | |||||
} | } | ||||
bool GradInfoCollection::contains(GradKey* key) { | |||||
_shrink(); | |||||
for (auto&& grad_info : m_storage) { | |||||
if (grad_info.grad_fn->key.lock().get() == key) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
void GradKeyWrapper::resume() { | |||||
m_transformation->resume(); | |||||
} | } | ||||
GradInfo& GradInfoCollection::operator[](GradKey* key) { | |||||
_shrink(); | |||||
for (auto&& grad_info : m_storage) { | |||||
if (grad_info.grad_fn->key.lock().get() == key) { | |||||
return grad_info; | |||||
} | |||||
} | |||||
m_storage.emplace_back(); | |||||
return m_storage.back(); | |||||
GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) { | |||||
return grad_key_map.at(key); | |||||
} | } | ||||
GradInfo& GradInfoCollection::at(GradKey* key) { | |||||
_shrink(); | |||||
for (auto&& grad_info : m_storage) { | |||||
if (grad_info.grad_fn->key.lock().get() == key) { | |||||
return grad_info; | |||||
} | |||||
} | |||||
mgb_assert(false); | |||||
GradKeyWrapper::~GradKeyWrapper() { | |||||
grad_key_map.erase(m_key); | |||||
} | } | ||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |
@@ -12,166 +12,40 @@ | |||||
#pragma once | #pragma once | ||||
#include "./tensor.h" | #include "./tensor.h" | ||||
#include "megbrain/imperative/ops/utility.h" | #include "megbrain/imperative/ops/utility.h" | ||||
#include "megbrain/imperative/transformations/grad.h" | |||||
#include "megbrain/utils/small_vector.h" | |||||
#include <megbrain/utils/small_vector.h> | |||||
#include <memory> | #include <memory> | ||||
#include <optional> | #include <optional> | ||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
apply_result_t apply_grad(ApplyContext& ctx); | |||||
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj { | |||||
std::string name; | |||||
bool active = true; | |||||
GradInfo::head_t free_vars_head; | |||||
std::vector<std::weak_ptr<GradFn>> tape; | |||||
int priority = 0; | |||||
~GradKey(); | |||||
void attach(Tensor* tensor, pybind11::object callback); | |||||
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||||
void cleanup(); | |||||
bool is_blocked() const { return priority < sm_min_priority; } | |||||
inline static bool allow_higher_order_directive = false; | |||||
private: | |||||
static int sm_min_priority; | |||||
}; | |||||
struct GradKeyWrapper { | |||||
struct GradKeyWrapper : NonCopyableObj { | |||||
using wrap_t = pyext17::wrap<GradKeyWrapper>; | using wrap_t = pyext17::wrap<GradKeyWrapper>; | ||||
static constexpr auto tp_name = pybind11::detail::_("GradKey"); | static constexpr auto tp_name = pybind11::detail::_("GradKey"); | ||||
std::shared_ptr<GradKey> m_key; | std::shared_ptr<GradKey> m_key; | ||||
std::shared_ptr<GradTransformation> m_transformation; | |||||
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {} | |||||
GradKeyWrapper(); | |||||
PyObject* get_name(); | PyObject* get_name(); | ||||
void set_name(pybind11::handle name); | void set_name(pybind11::handle name); | ||||
PyObject* get_priority(); | |||||
void set_priority(pybind11::handle priority); | |||||
void attach(PyObject* const* args, size_t nargs); | void attach(PyObject* const* args, size_t nargs); | ||||
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||||
static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list); | |||||
static pybind11::function get_backward_closure( | |||||
GradKeyWrapper* self, pybind11::list); | |||||
PyObject* is_attached_to(PyObject* const* args, size_t nargs); | PyObject* is_attached_to(PyObject* const* args, size_t nargs); | ||||
void enter(); | |||||
void exit(); | |||||
void suppress(); | |||||
void resume(); | |||||
static GradKeyWrapper* get(std::shared_ptr<GradKey> key); | |||||
~GradKeyWrapper(); | |||||
}; | }; | ||||
struct BackwardContext { | |||||
PyTypeObject* pytype = nullptr; | |||||
auto wrap_tensor(std::shared_ptr<Tensor> t) { | |||||
if (pytype) { | |||||
return TensorWrapper::make(pytype, std::move(t)); | |||||
} | |||||
return TensorWrapper::make(std::move(t)); | |||||
} | |||||
auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); } | |||||
}; | |||||
struct CustomBackward { | |||||
using BackwardFn = | |||||
std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>; | |||||
BackwardFn m_backward; | |||||
SmallVector<bool, 8> m_input_has_grad; | |||||
struct OutputAttr { | |||||
bool requires_grad = true, captured = true; | |||||
}; | |||||
SmallVector<OutputAttr> m_output_attrs; | |||||
public: | |||||
template <typename T, typename R> | |||||
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||||
size_t nargs = grads.size(); | |||||
Tensor* args[nargs]; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
args[i] = grads[i]; | |||||
} | |||||
auto ret = m_backward(ctx, args, nargs); | |||||
for (size_t i = 0; i < ret.size(); ++i) { | |||||
if (auto&& t = ret[i]) { | |||||
receiver(i, std::move(t)); | |||||
} | |||||
} | |||||
} | |||||
bool input_has_grad(size_t i) { return m_input_has_grad[i]; } | |||||
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } | |||||
bool output_captured(size_t i) { return m_output_attrs[i].captured; } | |||||
class Maker { | |||||
bool output_size_set = false, input_has_grad_initialized = false; | |||||
CustomBackward& target; | |||||
ApplyContext& ctx; | |||||
void init_input_has_grad() { | |||||
if (!input_has_grad_initialized) { | |||||
input_has_grad_initialized = true; | |||||
target.m_input_has_grad.resize(ctx.nargs, true); | |||||
} | |||||
} | |||||
public: | |||||
Maker(CustomBackward& target_, ApplyContext& ctx_) | |||||
: target(target_), ctx(ctx_) {} | |||||
template <typename F> | |||||
Maker& backward(F&& f) { | |||||
mgb_assert(!target.m_backward); | |||||
target.m_backward = std::forward<F>(f); | |||||
return *this; | |||||
} | |||||
// mandatory | |||||
Maker& output_size(size_t sz) { | |||||
mgb_assert(!output_size_set); | |||||
output_size_set = true; | |||||
target.m_output_attrs.resize(sz); | |||||
return *this; | |||||
} | |||||
// optional, defaults to all true | |||||
Maker& input_has_grad(size_t i, bool v) { | |||||
init_input_has_grad(); | |||||
target.m_input_has_grad.at(i) = v; | |||||
return *this; | |||||
} | |||||
// optional, defaults to all true | |||||
Maker& output_requires_grad(size_t i, bool v) { | |||||
target.m_output_attrs.at(i).requires_grad = v; | |||||
return *this; | |||||
} | |||||
// optional, defaults to all true | |||||
Maker& output_captured(size_t i, bool v) { | |||||
target.m_output_attrs.at(i).captured = v; | |||||
return *this; | |||||
} | |||||
void finalize() { | |||||
mgb_assert(output_size_set); | |||||
init_input_has_grad(); | |||||
} | |||||
}; | |||||
Maker maker(ApplyContext& ctx) { return {*this, ctx}; } | |||||
}; | |||||
using GradRuleFn = std::function<std::optional<apply_result_t>( | |||||
ApplyContext&, CustomBackward::Maker&)>; | |||||
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry(); | |||||
inline bool input_requires_grad(const ApplyContext& ctx, size_t i) { | |||||
return !ctx.args[i]->m_grad_info_dict.empty(); | |||||
} | |||||
struct GradRuleFallback : std::exception {}; | |||||
template <typename T> | |||||
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) { | |||||
return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second; | |||||
} | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
namespace pybind11::detail { | namespace pybind11::detail { | ||||
@@ -1,43 +0,0 @@ | |||||
/** | |||||
* \file imperative/python/src/grad_info.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <memory> | |||||
#include "./intrusive_list.h" | |||||
namespace mgb::imperative::python { | |||||
struct GradKey; | |||||
struct GradFn; | |||||
struct GradSlot; | |||||
struct GradSlotPtr { | |||||
std::shared_ptr<GradFn> grad_fn; | |||||
size_t idx; | |||||
operator bool() const; | |||||
GradSlot* operator->(); | |||||
}; | |||||
struct GradInfo : GradSlotPtr, | |||||
intrusive_list::Node<GradInfo, intrusive_list::before_t> { | |||||
GradInfo() = default; | |||||
GradInfo(GradInfo&) = default; | |||||
GradInfo(GradInfo&&) = default; | |||||
GradInfo& operator=(GradInfo&) = default; | |||||
GradInfo& operator=(GradInfo&&) = default; | |||||
GradInfo(const GradInfo& rhs) : GradInfo(const_cast<GradInfo&>(rhs)) {} | |||||
GradInfo& operator=(const GradInfo& rhs) { | |||||
return *this = const_cast<GradInfo&>(rhs); | |||||
} | |||||
}; | |||||
} // namespace mgb::imperative::python |
@@ -11,261 +11,334 @@ | |||||
#include "./grad.h" | #include "./grad.h" | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/transformations/grad.h" | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
class CustomGradMaker { | |||||
bool output_size_set = false, input_has_grad_initialized = false; | |||||
CustomBackward& target; | |||||
size_t nr_inputs; | |||||
void init_input_has_grad() { | |||||
if (!input_has_grad_initialized) { | |||||
input_has_grad_initialized = true; | |||||
target.m_input_has_grad.resize(nr_inputs, true); | |||||
} | |||||
} | |||||
public: | |||||
CustomGradMaker(CustomBackward& target, size_t nr_inputs) | |||||
: target(target), nr_inputs(nr_inputs) {} | |||||
CustomGradMaker& backward(CustomBackward::BackwardFn f) { | |||||
mgb_assert(!target.m_backward); | |||||
target.m_backward = f; | |||||
return *this; | |||||
} | |||||
// mandatory | |||||
CustomGradMaker& output_size(size_t sz) { | |||||
mgb_assert(!output_size_set); | |||||
output_size_set = true; | |||||
target.m_output_attrs.resize(sz); | |||||
return *this; | |||||
} | |||||
// optional, defaults to all true | |||||
CustomGradMaker& input_has_grad(size_t i, bool v) { | |||||
init_input_has_grad(); | |||||
target.m_input_has_grad.at(i) = v; | |||||
return *this; | |||||
} | |||||
// optional, defaults to all true | |||||
CustomGradMaker& output_requires_grad(size_t i, bool v) { | |||||
target.m_output_attrs.at(i).requires_grad = v; | |||||
return *this; | |||||
} | |||||
// optional, defaults to all true | |||||
CustomGradMaker& output_captured(size_t i, bool v) { | |||||
target.m_output_attrs.at(i).captured = v; | |||||
return *this; | |||||
} | |||||
void finalize() { | |||||
mgb_assert(output_size_set); | |||||
init_input_has_grad(); | |||||
} | |||||
}; | |||||
namespace { | namespace { | ||||
std::shared_ptr<Tensor> get_shape(Tensor* x) { | |||||
ValueRef get_shape(ValueRef x) { | |||||
static auto op = GetVarShape::make(); | static auto op = GetVarShape::make(); | ||||
return python::apply(op, x)[0]; | |||||
return imperative::apply(*op, x)[0]; | |||||
} | } | ||||
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) { | |||||
ValueRef reduce_to(ValueRef x, ValueRef s) { | |||||
static auto op = Reduce::make(); | static auto op = Reduce::make(); | ||||
return python::apply(op, x, s)[0]; | |||||
return imperative::apply(*op, x, s)[0]; | |||||
} | } | ||||
std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) { | |||||
ValueRef reshape_to(ValueRef x, ValueRef s) { | |||||
static auto op = Reshape::make(); | static auto op = Reshape::make(); | ||||
return python::apply(op, x, s)[0]; | |||||
return imperative::apply(*op, x, s)[0]; | |||||
} | } | ||||
std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) { | |||||
ValueRef broadcast_to(ValueRef x, ValueRef s) { | |||||
static auto op = Broadcast::make(); | static auto op = Broadcast::make(); | ||||
return python::apply(op, x, s)[0]; | |||||
return imperative::apply(*op, x, s)[0]; | |||||
} | } | ||||
std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) { | |||||
HostTensorND scalar{cn, {{1}, dtype}}; | |||||
std::memset(scalar.raw_ptr(), 0, dtype.size()); | |||||
interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false); | |||||
auto&& t = std::make_shared<Tensor>(handle); | |||||
auto res = broadcast_to(t.get(), shape); | |||||
ValueRef make_empty_tensor( | |||||
CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) { | |||||
HostTensorStorage storage(*device); | |||||
storage.ensure_size(dtype->size()); | |||||
std::memset(storage.ptr(), 0, dtype->size()); | |||||
auto t = imperative::apply( | |||||
CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()), | |||||
HostStorage::make(storage))[0]; | |||||
auto res = broadcast_to(t, shape); | |||||
return res; | return res; | ||||
} | } | ||||
std::optional<apply_result_t> elemwise_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
auto& op = ctx.op->cast_final_safe<Elemwise>(); | |||||
if (op.mode == Elemwise::Mode::ADD) { | |||||
mgb_assert(ctx.nargs == 2); | |||||
std::array<std::shared_ptr<Tensor>, 2> input_shapes; | |||||
std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto& elemwise = op.cast_final_safe<Elemwise>(); | |||||
if (elemwise.mode != Elemwise::Mode::ADD) { | |||||
return {}; | |||||
} | |||||
mgb_assert(inputs.size() == 2); | |||||
std::array<ValueRef, 2> input_shapes; | |||||
for (size_t i = 0; i < 2; ++i) { | |||||
if (inputs_require_grad[i]) { | |||||
input_shapes[i] = get_shape(inputs[i]); | |||||
} | |||||
} | |||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(2); | |||||
if (!grad) { | |||||
return ret; | |||||
} | |||||
for (size_t i = 0; i < 2; ++i) { | for (size_t i = 0; i < 2; ++i) { | ||||
if (input_requires_grad(ctx, i)) { | |||||
input_shapes[i] = get_shape(ctx.args[i]); | |||||
if (shapes[i]) { | |||||
ret[i] = reduce_to(grad, shapes[i]); | |||||
} | } | ||||
} | } | ||||
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([shapes = std::move(input_shapes)]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(2); | |||||
if (!grad) { | |||||
return ret; | |||||
} | |||||
for (size_t i = 0; i < 2; ++i) { | |||||
if (shapes[i]) { | |||||
ret[i] = reduce_to(grad, shapes[i].get()); | |||||
} | |||||
} | |||||
return ret; | |||||
}); | |||||
return apply(ctx); | |||||
} | |||||
return {}; | |||||
return ret; | |||||
}); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | } | ||||
std::optional<apply_result_t> reshape_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
mgb_assert(ctx.nargs == 2); | |||||
std::array<std::shared_ptr<Tensor>, 2> input_shapes; | |||||
std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
mgb_assert(inputs.size() == 2); | |||||
std::array<ValueRef, 2> input_shapes; | |||||
for (size_t i = 0; i < 2; ++i) { | for (size_t i = 0; i < 2; ++i) { | ||||
if (input_requires_grad(ctx, i)) { | |||||
input_shapes[i] = get_shape(ctx.args[i]); | |||||
if (inputs_require_grad[i]) { | |||||
input_shapes[i] = get_shape(inputs[i]); | |||||
} | } | ||||
} | } | ||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | maker.output_size(1).output_captured(0, false); | ||||
maker.backward([shapes = std::move(input_shapes)]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(2); | |||||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(2); | |||||
if (!grad) { | if (!grad) { | ||||
return ret; | return ret; | ||||
} | } | ||||
for (size_t i = 0; i < 2; ++i) { | for (size_t i = 0; i < 2; ++i) { | ||||
if (shapes[i]) { | if (shapes[i]) { | ||||
ret[i] = reshape_to(grad, shapes[i].get()); | |||||
ret[i] = reshape_to(grad, shapes[i]); | |||||
} | } | ||||
} | } | ||||
return ret; | return ret; | ||||
}); | }); | ||||
return apply(ctx); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | } | ||||
std::optional<apply_result_t> subtensor_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
auto&& op = ctx.op->cast_final_safe<Subtensor>(); | |||||
auto&& grad_op = SetSubtensor::make(op.items); | |||||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||||
if (input_requires_grad(ctx, 0)) { | |||||
inputs.push_back(get_shape(ctx.args[0])); | |||||
for (size_t i = 1; i < ctx.nargs; ++i) { | |||||
inputs.push_back(ctx.args[i]->copy()); | |||||
std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto&& subtensor = op.cast_final_safe<Subtensor>(); | |||||
auto&& grad_op = SetSubtensor::make(subtensor.items); | |||||
SmallVector<ValueRef> inputs2; | |||||
if (inputs_require_grad[0]) { | |||||
inputs2.push_back(get_shape(inputs[0])); | |||||
for (size_t i = 1; i < inputs.size(); ++i) { | |||||
inputs2.push_back(inputs[i]); | |||||
} | } | ||||
} | } | ||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | maker.output_size(1).output_captured(0, false); | ||||
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(1); | |||||
maker.backward([inputs = std::move(inputs2), | |||||
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(1); | |||||
if (grad && inputs[0]) { | if (grad && inputs[0]) { | ||||
SmallVector<Tensor*> args_(inputs.size() + 1); | |||||
auto&& zeros = make_empty_tensor( | |||||
grad->comp_node(), inputs[0].get(), grad->dtype()); | |||||
args_[0] = zeros.get(); | |||||
SmallVector<ValueRef> args_(inputs.size() + 1); | |||||
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||||
args_[0] = zeros; | |||||
args_[1] = grad; | args_[1] = grad; | ||||
for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
args_[i + 1] = inputs[i].get(); | |||||
args_[i + 1] = inputs[i]; | |||||
} | } | ||||
ret[0] = python::apply(grad_op_, args_)[0]; | |||||
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; | |||||
} | } | ||||
return ret; | return ret; | ||||
}); | }); | ||||
return apply(ctx); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | } | ||||
std::optional<apply_result_t> indexingMultiAxisVec_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>(); | |||||
auto&& grad_op = IndexingSetMultiAxisVec::make(op.items); | |||||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||||
if (input_requires_grad(ctx, 0)) { | |||||
inputs.push_back(get_shape(ctx.args[0])); | |||||
for (size_t i = 1; i < ctx.nargs; ++i) { | |||||
inputs.push_back(ctx.args[i]->copy()); | |||||
std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | |||||
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); | |||||
SmallVector<ValueRef> inputs2; | |||||
if (inputs_require_grad[0]) { | |||||
inputs2.push_back(get_shape(inputs[0])); | |||||
for (size_t i = 1; i < inputs.size(); ++i) { | |||||
inputs2.push_back(inputs[i]); | |||||
} | } | ||||
} | } | ||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | maker.output_size(1).output_captured(0, false); | ||||
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(1); | |||||
maker.backward([inputs = std::move(inputs2), | |||||
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(1); | |||||
if (grad && inputs[0]) { | if (grad && inputs[0]) { | ||||
SmallVector<Tensor*> args_(inputs.size() + 1); | |||||
auto&& zeros = make_empty_tensor( | |||||
grad->comp_node(), inputs[0].get(), grad->dtype()); | |||||
args_[0] = zeros.get(); | |||||
SmallVector<ValueRef> args_(inputs.size() + 1); | |||||
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||||
args_[0] = zeros; | |||||
args_[1] = grad; | args_[1] = grad; | ||||
for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
args_[i + 1] = inputs[i].get(); | |||||
args_[i + 1] = inputs[i]; | |||||
} | } | ||||
ret[0] = python::apply(grad_op_, args_)[0]; | |||||
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; | |||||
} | } | ||||
return ret; | return ret; | ||||
}); | }); | ||||
return apply(ctx); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | } | ||||
std::optional<apply_result_t> reduce_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
auto& op = ctx.op->cast_final_safe<Reduce>(); | |||||
if (op.mode == Reduce::Mode::SUM) { | |||||
if (ctx.nargs != 1) { | |||||
return {}; | |||||
} | |||||
std::array<std::shared_ptr<Tensor>, 1> input_shapes; | |||||
if (input_requires_grad(ctx, 0)) { | |||||
input_shapes[0] = get_shape(ctx.args[0]); | |||||
} | |||||
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([shapes = std::move(input_shapes)]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(1); | |||||
if (grad && shapes[0]) { | |||||
ret[0] = broadcast_to(grad, shapes[0].get()); | |||||
} | |||||
return ret; | |||||
}); | |||||
return apply(ctx); | |||||
std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto& reduce = op.cast_final_safe<Reduce>(); | |||||
if (reduce.mode != Reduce::Mode::SUM) { | |||||
return {}; | |||||
} | |||||
if (inputs.size() != 1) { | |||||
return {}; | |||||
} | |||||
std::array<ValueRef, 1> input_shapes; | |||||
if (inputs_require_grad[0]) { | |||||
input_shapes[0] = get_shape(inputs[0]); | |||||
} | } | ||||
return {}; | |||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(1); | |||||
if (grad && shapes[0]) { | |||||
ret[0] = broadcast_to(grad, shapes[0]); | |||||
} | |||||
return ret; | |||||
}); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | } | ||||
std::optional<apply_result_t> addAxis_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
auto&& op = ctx.op->cast_final_safe<AddAxis>(); | |||||
mgb_assert(ctx.nargs == 1); | |||||
bool flag = input_requires_grad(ctx, 0); | |||||
auto&& grad_op = RemoveAxis::make(op.axis); | |||||
std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto&& addAxis = op.cast_final_safe<AddAxis>(); | |||||
mgb_assert(inputs.size() == 1); | |||||
bool flag = inputs_require_grad[0]; | |||||
auto&& grad_op = RemoveAxis::make(addAxis.axis); | |||||
std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>()); | std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>()); | ||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | maker.output_size(1).output_captured(0, false); | ||||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(1); | |||||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(1); | |||||
if (grad && flag_) { | if (grad && flag_) { | ||||
ret[0] = python::apply(grad_op_, grad)[0]; | |||||
ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||||
} | } | ||||
return ret; | return ret; | ||||
}); | }); | ||||
return apply(ctx); | |||||
maker.finalize(); | |||||
return imperative::apply(op, inputs); | |||||
} | } | ||||
std::optional<apply_result_t> removeAxis_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
auto&& op = ctx.op->cast_final_safe<RemoveAxis>(); | |||||
mgb_assert(ctx.nargs == 1); | |||||
bool flag = input_requires_grad(ctx, 0); | |||||
auto&& grad_op = AddAxis::make(op.axis); | |||||
std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto&& removeAxis = op.cast_final_safe<RemoveAxis>(); | |||||
mgb_assert(inputs.size() == 1); | |||||
bool flag = inputs_require_grad[0]; | |||||
auto&& grad_op = AddAxis::make(removeAxis.axis); | |||||
std::sort(grad_op->axis.begin(), grad_op->axis.end()); | std::sort(grad_op->axis.begin(), grad_op->axis.end()); | ||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | maker.output_size(1).output_captured(0, false); | ||||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( | |||||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(1); | |||||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(1); | |||||
if (grad && flag_) { | if (grad && flag_) { | ||||
ret[0] = python::apply(grad_op_, grad)[0]; | |||||
ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||||
} | } | ||||
return ret; | return ret; | ||||
}); | }); | ||||
return apply(ctx); | |||||
maker.finalize(); | |||||
return imperative::apply(op, inputs); | |||||
} | } | ||||
std::optional<apply_result_t> fastpathcopy_grad_rule( | |||||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||||
mgb_assert(ctx.nargs == 1); | |||||
std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
mgb_assert(inputs.size() == 1); | |||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | maker.output_size(1).output_captured(0, false); | ||||
maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||||
mgb_assert(ngrads == 1); | |||||
Tensor* grad = grads[0]; | |||||
apply_result_t ret(1); | |||||
maker.backward([](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
std::vector<ValueRef> ret(1); | |||||
if (grad) { | if (grad) { | ||||
ret[0] = grad->shared_from_this(); | |||||
ret[0] = grad; | |||||
} | } | ||||
return ret; | return ret; | ||||
}); | }); | ||||
return apply(ctx); | |||||
maker.finalize(); | |||||
return imperative::apply(op, inputs); | |||||
} | } | ||||
struct Init { | struct Init { | ||||
Init() { | Init() { | ||||
auto& reg = grad_rule_registry(); | |||||
reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule); | |||||
reg.emplace(Reshape::typeinfo(), reshape_grad_rule); | |||||
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); | |||||
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||||
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | |||||
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | |||||
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||||
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||||
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); | |||||
CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule); | |||||
CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule); | |||||
CustomBackward::register_grad_rule( | |||||
IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||||
CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule); | |||||
CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule); | |||||
CustomBackward::register_grad_rule( | |||||
RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||||
CustomBackward::register_grad_rule( | |||||
FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||||
} | } | ||||
} _; | } _; | ||||
@@ -1,245 +0,0 @@ | |||||
/** | |||||
* \file imperative/python/src/intrusive_list.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb::imperative::python::intrusive_list { | |||||
// copy policy | |||||
struct after_t {}; | |||||
struct before_t {}; | |||||
struct disable_t {}; | |||||
template <typename T> | |||||
struct Tail; | |||||
// invariant: next->prev == this | |||||
template <typename T> | |||||
struct Head { | |||||
Tail<T>* next; | |||||
Head(Tail<T>* node = nullptr) : next(node) {} | |||||
Head(const Head<T>&) = delete; | |||||
Head<T>& operator=(const Head<T>&) = delete; | |||||
Head(Head<T>&& rhs) : next(rhs.next) { | |||||
rhs.next = nullptr; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
Head<T>& operator=(Head<T>&& rhs) { | |||||
mgb_assert(!next); | |||||
next = rhs.next; | |||||
rhs.next = nullptr; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
return *this; | |||||
} | |||||
~Head() { | |||||
if (next) { | |||||
next->prev = nullptr; | |||||
} | |||||
} | |||||
}; | |||||
// invariant: prev->next == this | |||||
template <typename T> | |||||
struct Tail { | |||||
Head<T>* prev; | |||||
Tail(Head<T>* node = nullptr) : prev(node) {} | |||||
Tail(const Tail<T>&) = delete; | |||||
Tail<T>& operator=(const Tail<T>&) = delete; | |||||
Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||||
rhs.prev = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
} | |||||
Tail<T>& operator=(Tail<T>&& rhs) { | |||||
mgb_assert(!prev); | |||||
prev = rhs.prev; | |||||
rhs.prev = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
return *this; | |||||
} | |||||
~Tail() { | |||||
if (prev) { | |||||
prev->next = nullptr; | |||||
} | |||||
} | |||||
}; | |||||
template <typename T, typename policy> | |||||
struct Node; | |||||
template <typename T> | |||||
class Iterator { | |||||
T* ptr; | |||||
void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); } | |||||
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); } | |||||
public: | |||||
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||||
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||||
template <typename policy> | |||||
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||||
T& operator*() { return *static_cast<T*>(ptr); } | |||||
T* operator->() { return static_cast<T*>(ptr); } | |||||
operator bool() { return ptr; } | |||||
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; } | |||||
Iterator& operator++() { | |||||
inc(); | |||||
return *this; | |||||
} | |||||
Iterator& operator--() { | |||||
dec(); | |||||
return *this; | |||||
} | |||||
Iterator operator++(int) { | |||||
auto ret = *this; | |||||
inc(); | |||||
return ret; | |||||
} | |||||
Iterator operator--(int) { | |||||
auto ret = *this; | |||||
dec(); | |||||
return ret; | |||||
} | |||||
}; | |||||
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||||
// Instead, nodes may join or leave a list freely. | |||||
// NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||||
// otherwise the compiler generated version would use the const T& signature, | |||||
// which is deleted. | |||||
template <typename T = void, typename policy = disable_t> | |||||
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||||
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||||
private: | |||||
using this_t = Node<T, policy>; | |||||
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||||
public: | |||||
using head_t = Head<U>; | |||||
using tail_t = Tail<U>; | |||||
using head_t::next; | |||||
using tail_t::prev; | |||||
Node() = default; | |||||
Node(const this_t&) = delete; | |||||
this_t& operator=(const this_t&) = delete; | |||||
//! constructed node is inserted after the input node | |||||
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||||
node.next = this; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
//! constructed node is inserted before the input node | |||||
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||||
node.prev = this; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
} | |||||
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||||
rhs.prev = nullptr; | |||||
rhs.next = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
Node& operator=(this_t&& rhs) { | |||||
unlink(); | |||||
prev = rhs.prev; | |||||
next = rhs.next; | |||||
rhs.prev = nullptr; | |||||
rhs.next = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
return *this; | |||||
} | |||||
template < | |||||
typename p = policy, | |||||
typename = std::enable_if_t< | |||||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
Node(this_t& rhs) : Node(policy{}, rhs) {} | |||||
template < | |||||
typename p = policy, | |||||
typename = std::enable_if_t< | |||||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
this_t& operator=(this_t& rhs) { | |||||
insert(policy{}, rhs); | |||||
return *this; | |||||
} | |||||
void unlink() { | |||||
if (prev) { | |||||
prev->next = next; | |||||
} | |||||
if (next) { | |||||
next->prev = prev; | |||||
} | |||||
prev = nullptr; | |||||
next = nullptr; | |||||
} | |||||
//! this node is unlinked from its list and inserted after the input node | |||||
void insert(after_t, head_t& node) { | |||||
unlink(); | |||||
prev = &node; | |||||
next = node.next; | |||||
node.next = this; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
//! this node is unlinked from its list and inserted before the input node | |||||
void insert(before_t, tail_t& node) { | |||||
unlink(); | |||||
next = &node; | |||||
prev = node.prev; | |||||
node.prev = this; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
} | |||||
void insert_before(tail_t& node) { insert(before_t{}, node); } | |||||
void insert_after(head_t& node) { insert(after_t{}, node); } | |||||
~Node() { unlink(); } | |||||
}; | |||||
} // namespace mgb::imperative::python::intrusive_list |
@@ -1,42 +0,0 @@ | |||||
/** | |||||
* \file imperative/python/src/module_trace.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "./module_trace.h" | |||||
#include "./helper.h" // include op pybind11 caster | |||||
namespace py = pybind11; | |||||
namespace mgb::imperative::python { | |||||
// Forwards an op application to the Python-side module-trace hook.
//
// Builds the argument tuple (op, *tensors), calls `cpp_apply_module_trace`,
// and converts the returned Python list back into C++ tensors.
// Throws py::error_already_set if the Python call raises.
apply_result_t apply_module_trace(ApplyContext& ctx) {
    auto args = py::tuple(ctx.nargs + 1);
    args[0] = py::cast(ctx.op);
    for (size_t i = 0; i < ctx.nargs; i++) {
        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
    }
    auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr);
    if (!pyout)
        throw py::error_already_set();
    // PyObject_Call returns a new reference; take ownership of it directly.
    // assumption: python function always returns PyList
    auto results = py::reinterpret_steal<py::list>(pyout);
    apply_result_t outputs;
    for (size_t i = 0; i < results.size(); i++) {
        auto tw = TensorWrapper::try_cast(results[i].ptr());
        // NOTE(review): try_cast presumably yields null for non-tensor objects;
        // fail loudly instead of dereferencing a null wrapper.
        mgb_assert(tw, "module trace hook returned a non-tensor output");
        outputs.emplace_back(tw->m_tensor);
    }
    return outputs;
}
} // namespace mgb::imperative::python |
@@ -11,10 +11,50 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/imperative/transformations/trace.h" | |||||
#include "megbrain/imperative/utils/map.h" | |||||
#include "./tensor.h" | #include "./tensor.h" | ||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
apply_result_t apply_module_trace(ApplyContext& ctx); | |||||
namespace py = pybind11; | |||||
class ModuleTraceTransformation final : public Transformation { | |||||
private: | |||||
py::function m_hook_fn; | |||||
int m_enabled = 0; | |||||
std::vector<ValueRef> apply_module_trace_hook( | |||||
const OpDef& op, Span<ValueRef> input_values) { | |||||
py::list input_tws; | |||||
for (auto&& input_value : input_values) { | |||||
input_tws.append(TensorWrapper::make(py_tensor_type, input_value)); | |||||
} | |||||
py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws); | |||||
std::vector<ValueRef> outputs; | |||||
for (auto&& output_tw : output_tws) { | |||||
outputs.push_back( | |||||
TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data()); | |||||
} | |||||
return outputs; | |||||
} | |||||
public: | |||||
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | |||||
std::vector<ValueRef> apply_transformation( | |||||
const Operator& op, Span<ValueRef> inputs) override { | |||||
if (op.is<ApplyOp>() && m_enabled > 0) { | |||||
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | |||||
return outputs; | |||||
} else { | |||||
return imperative::apply(op, inputs); | |||||
} | |||||
} | |||||
ValueRef unwrap(ValueRef value) override { return value; } | |||||
std::string name() const override { return "ModuleTraceTransformation"; } | |||||
}; | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |
@@ -185,7 +185,8 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) { | |||||
} | } | ||||
PyGetSetDef PyOp(OpDef)::py_getsetters[] = { | PyGetSetDef PyOp(OpDef)::py_getsetters[] = { | ||||
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL}, | |||||
{const_cast<char*>("scope"), py_get_scope, py_set_scope, | |||||
const_cast<char*>("scope"), NULL}, | |||||
{NULL}}; | {NULL}}; | ||||
Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) { | Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) { | ||||
@@ -556,12 +557,6 @@ void init_ops(py::module m) { | |||||
m.def( | m.def( | ||||
"delete_rng_handle", | "delete_rng_handle", | ||||
[](size_t handle) { | [](size_t handle) { | ||||
// RNG op might execute after handle released due to async dispatch, so | |||||
// we need sync before delete a handle to avoid memory leak or | |||||
// use-after-free | |||||
if (python::interpreter_for_py->check_available()) { | |||||
python::interpreter_for_py->sync(); | |||||
} | |||||
mgb::CompNode::sync_all(); | mgb::CompNode::sync_all(); | ||||
py_task_q.wait_all_task_finish(); | py_task_q.wait_all_task_finish(); | ||||
rng::delete_handle(handle); | rng::delete_handle(handle); | ||||
@@ -20,6 +20,8 @@ | |||||
#include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
#include "./pyext17.h" | #include "./pyext17.h" | ||||
#include "megbrain/imperative/dispatch.h" | |||||
#include "megbrain/imperative/utils/span.h" | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
@@ -32,126 +34,67 @@ struct ObjectPtr : B { | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
#include "./grad_info.h" // for struct GradInfo | |||||
#include "./trace_info.h" // for struct TraceInfo | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
struct GradKey; | |||||
extern interpreter::Interpreter::Channel* interpreter_for_py; | extern interpreter::Interpreter::Channel* interpreter_for_py; | ||||
extern PyTypeObject* py_tensor_type; | |||||
class SharedHandle { | |||||
using Handle = interpreter::Interpreter::Handle; | |||||
static_assert(std::is_pointer_v<Handle>); | |||||
std::shared_ptr<std::remove_pointer_t<Handle>> holder; | |||||
public: | |||||
inline explicit SharedHandle(Handle handle) | |||||
: holder(handle, [](auto* h) { | |||||
if (h) { | |||||
interpreter_for_py->del(h); | |||||
} | |||||
}) {} | |||||
SharedHandle(const SharedHandle&) = default; | |||||
SharedHandle& operator=(const SharedHandle&) = default; | |||||
SharedHandle(SharedHandle&&) = default; | |||||
SharedHandle& operator=(SharedHandle&&) = default; | |||||
inline Handle get() { return holder.get(); } | |||||
}; | |||||
// impl in grad.cpp | |||||
class GradInfoCollection { | |||||
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
private: | private: | ||||
SmallVector<GradInfo> m_storage; | |||||
protected: | |||||
void _shrink(); | |||||
std::string m_name; | |||||
ValueRef m_data; | |||||
public: | public: | ||||
bool contains(GradKey* key); | |||||
GradInfo& operator[](GradKey* key); | |||||
GradInfo& at(GradKey* key); | |||||
bool empty() { | |||||
_shrink(); | |||||
return m_storage.empty(); | |||||
} | |||||
auto begin() { | |||||
_shrink(); | |||||
return m_storage.begin(); | |||||
} | |||||
auto end() { | |||||
_shrink(); | |||||
return m_storage.end(); | |||||
} | |||||
size_t count(GradKey* key) { return contains(key) ? 1 : 0; } | |||||
}; | |||||
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
using flags_t = uint64_t; | |||||
struct Flags { | |||||
static constexpr flags_t SCALAR = 1; | |||||
static constexpr flags_t GRAD = 1 << 1; | |||||
static constexpr flags_t TRACE = 1 << 2; | |||||
static constexpr flags_t MODULE_TRACE = 1 << 3; | |||||
}; | |||||
flags_t m_flags = 0; | |||||
GradInfoCollection m_grad_info_dict; | |||||
TraceInfo m_trace_info; | |||||
SharedHandle m_handle; | |||||
std::string user_custom_name; | |||||
std::string automatic_name; | |||||
cg::VarNode* m_var; | |||||
pybind11::object m_module_trace_info; | |||||
using Handle = interpreter::Interpreter::Handle; | using Handle = interpreter::Interpreter::Handle; | ||||
inline Tensor() : m_handle(nullptr), m_var(nullptr) {} | |||||
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} | |||||
inline explicit Tensor(SharedHandle handle) | |||||
: m_handle(std::move(handle)), m_var(nullptr) {} | |||||
inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {} | |||||
inline explicit Tensor(ValueRef data) : m_data{data} {} | |||||
~Tensor() = default; | ~Tensor() = default; | ||||
inline std::shared_ptr<Tensor> copy() { | inline std::shared_ptr<Tensor> copy() { | ||||
auto ret = std::make_shared<Tensor>(m_handle); | |||||
ret->m_flags = m_flags; | |||||
ret->m_grad_info_dict = m_grad_info_dict; | |||||
ret->m_trace_info = m_trace_info; | |||||
ret->m_var = m_var; | |||||
auto ret = std::make_shared<Tensor>(m_data.unwrap()); | |||||
ret->m_name = m_name; | |||||
return ret; | return ret; | ||||
} | } | ||||
inline DType dtype() { | |||||
if (m_var) { | |||||
return m_var->dtype(); | |||||
inline DType dtype() { return *data().dtype(); } | |||||
inline CompNode comp_node() { return *data().device(); } | |||||
inline std::optional<ValueShape> shape() { | |||||
auto shape = data().shape(); | |||||
if (!shape) { | |||||
return {}; | |||||
} | } | ||||
return interpreter_for_py->get_dtype(m_handle.get()); | |||||
return *shape; | |||||
} | } | ||||
inline CompNode comp_node() { | |||||
if (m_var) { | |||||
return m_var->comp_node(); | |||||
inline HostValue::ref_t numpy() { return data().numpy(); } | |||||
inline void reset(ValueRef value) { | |||||
m_data = value; | |||||
if (!m_name.empty()) { | |||||
set_name(m_name); | |||||
} | } | ||||
return interpreter_for_py->get_device(m_handle.get()); | |||||
} | } | ||||
inline TensorShape shape() { | |||||
if (m_var) { | |||||
return m_var->shape(); | |||||
inline ValueRef data() { return m_data.unwrap(); } | |||||
bool is_scalar() { return data().is_scalar(); } | |||||
inline std::string name() { return m_name; } | |||||
inline void set_name(std::string name) { | |||||
m_name = name; | |||||
if (!name.empty()) { | |||||
auto output = imperative::apply(RenameValue(name), m_data)[0]; | |||||
m_data = output; | |||||
} | } | ||||
return interpreter_for_py->get_shape(m_handle.get()); | |||||
} | } | ||||
}; | }; | ||||
struct TensorWrapper { | struct TensorWrapper { | ||||
public: | |||||
std::shared_ptr<Tensor> m_tensor; | std::shared_ptr<Tensor> m_tensor; | ||||
inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) | inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) | ||||
: m_tensor(std::move(tensor)) {} | |||||
: m_tensor(std::move(tensor)) { | |||||
mgb_assert(tensor, "empty storage"); | |||||
} | |||||
inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {} | |||||
TensorWrapper(PyObject* args, PyObject* kwargs); | TensorWrapper(PyObject* args, PyObject* kwargs); | ||||
~TensorWrapper() = default; | ~TensorWrapper() = default; | ||||
@@ -191,33 +134,17 @@ struct TensorWrapper { | |||||
void reset(PyObject*); | void reset(PyObject*); | ||||
PyObject* detach(); | PyObject* detach(); | ||||
PyObject* isscalar(); | PyObject* isscalar(); | ||||
void setscalar(); | |||||
void unsetscalar(); | |||||
PyObject* _dev_tensor(); | PyObject* _dev_tensor(); | ||||
void _drop(); | void _drop(); | ||||
PyObject* varnode(); | PyObject* varnode(); | ||||
void reset_varnode(); | |||||
PyObject* handle(); | |||||
void set_handle(PyObject*); | |||||
PyObject* mixin_handle(); | |||||
PyObject* recording(); | PyObject* recording(); | ||||
PyObject* copied(); | PyObject* copied(); | ||||
void set_mixin_handle(PyObject*); | |||||
void set_recording(PyObject*); | |||||
PyObject* compiled_info(); | |||||
void set_compiled_info(PyObject*); | |||||
PyObject* trace_mixin_info(); | |||||
void set_trace_mixin_info(PyObject*); | |||||
PyObject* module_trace_info(); | PyObject* module_trace_info(); | ||||
void set_module_trace_info(PyObject*); | void set_module_trace_info(PyObject*); | ||||
PyObject* user_custom_name(); | |||||
void set_user_custom_name(PyObject*); | |||||
PyObject* automatic_name(); | |||||
void set_automatic_name(PyObject*); | |||||
void _set_name(PyObject*); | |||||
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | ||||
PyObject* _detail(); | |||||
void _watch(); | |||||
}; | }; | ||||
struct PySymbolVar { | struct PySymbolVar { | ||||
@@ -230,113 +157,8 @@ struct PySymbolVar { | |||||
PyObject* py_apply( | PyObject* py_apply( | ||||
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); | PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); | ||||
struct ApplyContext { | |||||
static Tensor::flags_t global_disable; | |||||
static Tensor::flags_t global_enable; | |||||
Tensor::flags_t flags = 0; | |||||
std::shared_ptr<OpDef> op; | |||||
Tensor* const* args; | |||||
size_t nargs; | |||||
PyTypeObject* pytype = nullptr; | |||||
bool backward = false; | |||||
class scoped_disable : NonCopyableObj { | |||||
Tensor::flags_t saved_flags; | |||||
public: | |||||
scoped_disable(Tensor::flags_t flags) | |||||
: saved_flags(ApplyContext::global_disable) { | |||||
ApplyContext::global_disable |= flags; | |||||
} | |||||
~scoped_disable() { ApplyContext::global_disable = saved_flags; } | |||||
}; | |||||
}; | |||||
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||||
apply_result_t apply(ApplyContext& ctx); | |||||
template <typename T> | |||||
decltype(auto) resolve_arrow(T&& p) { | |||||
if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) { | |||||
auto* ret = p; | |||||
return ret; | |||||
} else { | |||||
auto probe = [](auto&& p) -> decltype(p.operator->()) {}; | |||||
if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { | |||||
return resolve_arrow(p.operator->()); | |||||
} else { | |||||
return std::forward<T>(p); | |||||
} | |||||
} | |||||
} | |||||
template <typename... Args> | |||||
constexpr bool is_all_tensor_ptr = | |||||
(... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); | |||||
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> | |||||
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | |||||
ApplyContext ctx; | |||||
Tensor* arg_arr[] = {resolve_arrow(args)...}; | |||||
ctx.flags = (0 | ... | args->m_flags); | |||||
ctx.args = arg_arr; | |||||
ctx.nargs = sizeof...(args); | |||||
ctx.op = std::move(op); | |||||
return apply(ctx); | |||||
} | |||||
inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) { | |||||
ApplyContext ctx; | |||||
ctx.op = std::move(op); | |||||
ctx.nargs = nargs; | |||||
ctx.args = args; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
ctx.flags |= args[i]->m_flags; | |||||
} | |||||
return apply(ctx); | |||||
} | |||||
template <typename T> | |||||
auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t< | |||||
std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> { | |||||
size_t nargs = tensors.size(); | |||||
Tensor* args[nargs]; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
args[i] = resolve_arrow(tensors[i]); | |||||
} | |||||
return apply(op, args, nargs); | |||||
} | |||||
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value); | |||||
inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) { | |||||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
inputs.push_back(args[i]->shared_from_this()); | |||||
} | |||||
auto apply_functor = [](std::shared_ptr<OpDef> op, | |||||
SmallVector<std::shared_ptr<Tensor>> inputs, | |||||
size_t) { return apply(op, std::move(inputs)); }; | |||||
return graph.apply(inputs, apply_functor, &make_const); | |||||
} | |||||
template <typename T> | |||||
auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t< | |||||
std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> { | |||||
size_t nargs = tensors.size(); | |||||
Tensor* args[nargs]; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
args[i] = resolve_arrow(tensors[i]); | |||||
} | |||||
return apply(graph, args, nargs); | |||||
} | |||||
void init_tensor(pybind11::module); | void init_tensor(pybind11::module); | ||||
extern PyObject* cpp_apply_with_tracing; | |||||
extern PyObject* cpp_apply_backward_varnode; | |||||
extern PyObject* cpp_apply_module_trace; | extern PyObject* cpp_apply_module_trace; | ||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
@@ -1,63 +0,0 @@ | |||||
/** | |||||
* \file imperative/python/src/trace.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "./trace.h" | |||||
#include "./helper.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
namespace py = pybind11; | |||||
namespace mgb::imperative::python { | |||||
// Forwards an op application to the Python tracing callbacks.
//
// When ctx.backward is set, the op is applied on VarNodes through
// `cpp_apply_backward_varnode`; otherwise it is applied on tensors through
// `cpp_apply_with_tracing`. Both callbacks are expected to return a Python list.
// Throws py::error_already_set if the Python call raises.
apply_result_t apply_trace(ApplyContext& ctx) {
    apply_result_t outputs;

    if (ctx.backward) {
        // reach here when compiled=True
        auto args = py::tuple(ctx.nargs + 1);
        args[0] = py::cast(ctx.op);
        for (size_t i = 0; i < ctx.nargs; i++) {
            args[i + 1] = py::cast(ctx.args[i]->m_var);
        }
        py::object pyout = py::reinterpret_steal<py::object>(
                PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
        if (!pyout)
            throw py::error_already_set();

        // assumption: python function always returns PyList
        auto tup = py::reinterpret_borrow<py::list>(pyout);
        for (size_t i = 0; i < tup.size(); i++) {
            auto pitem = tup[i].cast<cg::VarNode*>();
            outputs.emplace_back(std::make_shared<Tensor>(pitem));
        }
        return outputs;
    }

    auto args = py::tuple(ctx.nargs + 1);
    args[0] = py::cast(ctx.op);
    for (size_t i = 0; i < ctx.nargs; i++) {
        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
    }
    auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr);
    if (!pyout)
        throw py::error_already_set();

    // assumption: python function always returns PyList
    auto tup = py::reinterpret_steal<py::list>(pyout);
    for (size_t i = 0; i < tup.size(); i++) {
        auto tw = TensorWrapper::try_cast(tup[i].ptr());
        // NOTE(review): try_cast presumably yields null for non-tensor objects;
        // fail loudly instead of dereferencing a null wrapper.
        mgb_assert(tw, "tracing callback returned a non-tensor output");
        outputs.emplace_back(tw->m_tensor);
    }
    return outputs;
}
} // namespace mgb::imperative::python |
@@ -1,28 +0,0 @@ | |||||
/** | |||||
* \file imperative/python/src/trace.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <stdexcept> | |||||
#include "./tensor.h" | |||||
namespace mgb::imperative::python { | |||||
// Exception type carrying a message for trace read failures.
class TraceReadError : public std::exception {
    // Owned copy of the text handed to the constructor.
    std::string message = "";

public:
    // `m` must be a valid null-terminated string; it is copied.
    explicit TraceReadError(const char* m) : message{m} {}

    // Returns the stored message; the pointer stays valid while *this lives.
    const char* what() const noexcept override {
        const std::string& text = message;
        return text.c_str();
    }
};
apply_result_t apply_trace(ApplyContext& ctx); | |||||
} // namespace mgb::imperative::python |
@@ -1,49 +0,0 @@ | |||||
/** | |||||
* \file imperative/python/src/trace_info.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "Python.h" | |||||
#include "inttypes.h" | |||||
namespace mgb::imperative::python { | |||||
struct TraceInfo { | |||||
int64_t mixin_handle = -1; | |||||
bool recording = false; | |||||
// refer to CompiledTensorProxy in tracing.py, works from second trace step | |||||
PyObject* compiled_info = nullptr; | |||||
// refer to TensorInfo in tracing.py, only works in first trace step | |||||
PyObject* trace_mixin_info = nullptr; | |||||
TraceInfo() = default; | |||||
TraceInfo& operator=(const TraceInfo& that) { | |||||
mixin_handle = that.mixin_handle; | |||||
recording = that.recording; | |||||
trace_mixin_info = that.trace_mixin_info; | |||||
Py_XINCREF(trace_mixin_info); | |||||
compiled_info = that.compiled_info; | |||||
Py_XINCREF(compiled_info); | |||||
return *this; | |||||
} | |||||
~TraceInfo() { | |||||
Py_XDECREF(trace_mixin_info); | |||||
Py_XDECREF(compiled_info); | |||||
} | |||||
private: | |||||
TraceInfo(const TraceInfo& that) = default; | |||||
}; | |||||
} // namespace mgb::imperative::python |
@@ -16,10 +16,6 @@ import megengine.module | |||||
from megengine import Parameter | from megengine import Parameter | ||||
from megengine.core._imperative_rt.core2 import sync | from megengine.core._imperative_rt.core2 import sync | ||||
from megengine.device import get_device_count | from megengine.device import get_device_count | ||||
from megengine.experimental.autograd import ( | |||||
disable_higher_order_directive, | |||||
enable_higher_order_directive, | |||||
) | |||||
from megengine.jit import trace as _trace | from megengine.jit import trace as _trace | ||||
from megengine.module import Linear, Module | from megengine.module import Linear, Module | ||||
@@ -45,13 +41,3 @@ def skip_distributed(request): | |||||
platform.system() | platform.system() | ||||
) | ) | ||||
) | ) | ||||
@pytest.fixture(autouse=True) | |||||
def resolve_require_higher_order_directive(request): | |||||
marker = request.node.get_closest_marker("require_higher_order_directive") | |||||
if marker: | |||||
enable_higher_order_directive() | |||||
yield | |||||
if marker: | |||||
disable_higher_order_directive() |
@@ -146,5 +146,5 @@ def test_dump_bn_train_mode(): | |||||
data = mge.tensor(np.random.random((10, 10, 10, 10))) | data = mge.tensor(np.random.random((10, 10, 10, 10))) | ||||
bn_train(data) | bn_train(data) | ||||
with pytest.raises(AssertionError): | |||||
with pytest.raises(RuntimeError): | |||||
bn_train.dump("test.mge") | bn_train.dump("test.mge") |
@@ -17,7 +17,7 @@ import megengine.distributed as dist | |||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | import megengine.module as M | ||||
import megengine.optimizer as optim | import megengine.optimizer as optim | ||||
from megengine.autodiff import GradManager | |||||
from megengine.autodiff import Function, GradManager | |||||
from megengine.jit import trace | from megengine.jit import trace | ||||
@@ -214,7 +214,7 @@ def test_remote_grad(trace_mode): | |||||
x = dist.functional.remote_recv(rank - 1) | x = dist.functional.remote_recv(rank - 1) | ||||
y = m(x) | y = m(x) | ||||
if rank != size - 1: | if rank != size - 1: | ||||
dist.functional.remote_send(y, dest_rank=rank + 1) | |||||
x = dist.functional.remote_send(y, dest_rank=rank + 1) | |||||
gm.backward() | gm.backward() | ||||
else: | else: | ||||
y = y.mean() | y = y.mean() | ||||
@@ -224,7 +224,7 @@ def test_remote_grad(trace_mode): | |||||
if trace_mode is not None: | if trace_mode is not None: | ||||
train_func = trace(symbolic=trace_mode)(train_func) | train_func = trace(symbolic=trace_mode)(train_func) | ||||
for i in range(3): | |||||
for i in range(1): | |||||
train_func(x) | train_func(x) | ||||
worker() | worker() | ||||
@@ -340,7 +340,6 @@ def test_broadcast_grad(trace_mode): | |||||
worker() | worker() | ||||
@pytest.mark.require_higher_order_directive() | |||||
def test_2nd_grad_with_manager(): | def test_2nd_grad_with_manager(): | ||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = mge.tensor(x_np) | x = mge.tensor(x_np) | ||||
@@ -359,7 +358,6 @@ def test_2nd_grad_with_manager(): | |||||
) | ) | ||||
@pytest.mark.require_higher_order_directive() | |||||
def test_grad_manager_group(): | def test_grad_manager_group(): | ||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = mge.tensor(x_np) | x = mge.tensor(x_np) | ||||
@@ -376,7 +374,6 @@ def test_grad_manager_group(): | |||||
x.grad = None | x.grad = None | ||||
@pytest.mark.require_higher_order_directive() | |||||
def test_grad_manager_group_visibility(): | def test_grad_manager_group_visibility(): | ||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = mge.tensor(x_np) | x = mge.tensor(x_np) | ||||
@@ -392,7 +389,6 @@ def test_grad_manager_group_visibility(): | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | ||||
@pytest.mark.require_higher_order_directive() | |||||
def test_grad_manager_visibility_by_order(): | def test_grad_manager_visibility_by_order(): | ||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = mge.tensor(x_np) | x = mge.tensor(x_np) | ||||
@@ -410,7 +406,6 @@ def test_grad_manager_visibility_by_order(): | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | ||||
@pytest.mark.require_higher_order_directive() | |||||
@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) | @pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) | ||||
def test_emulate_forward_mode_with_reverse_mode(target): | def test_emulate_forward_mode_with_reverse_mode(target): | ||||
def jvp(inp, expr): | def jvp(inp, expr): | ||||
@@ -434,3 +429,43 @@ def test_emulate_forward_mode_with_reverse_mode(target): | |||||
np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) | np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) | ||||
np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) | np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) | ||||
def test_2nd_grad_with_custom_gradient(): | |||||
class MySin(Function): | |||||
def forward(self, x): | |||||
self.inp = x | |||||
x = mge.Tensor(x.numpy()) | |||||
y = F.sin(x) | |||||
return y | |||||
def backward(self, dy): | |||||
dx = F.cos(self.inp) * dy | |||||
return dx | |||||
class MyCos(Function): | |||||
def forward(self, x): | |||||
self.inp = x | |||||
x = mge.Tensor(x.numpy()) | |||||
y = F.cos(x) | |||||
return y | |||||
def backward(self, dy): | |||||
dx = -MySin()(self.inp) * dy | |||||
return dx | |||||
x_np = np.random.rand(10).astype("float32") | |||||
x = mge.tensor(x_np) | |||||
gm = GradManager().attach([x]) | |||||
gm2 = GradManager().attach([x]) | |||||
with gm: | |||||
with gm2: | |||||
y = MyCos()(x) | |||||
gm2.backward(y) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
gm.backward(x.grad) | |||||
np.testing.assert_almost_equal( | |||||
x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5 | |||||
) |
@@ -7,8 +7,6 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import gc | import gc | ||||
import platform | |||||
import weakref | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
@@ -60,24 +58,20 @@ def test_dist_grad(): | |||||
def worker(): | def worker(): | ||||
rank = dist.get_rank() | rank = dist.get_rank() | ||||
if rank == 0: | if rank == 0: | ||||
grad = Grad() | |||||
x = as_tensor(x_np) | |||||
grad.wrt(x, callback=save_to(x)) | |||||
# need a placeholder to trace operator | |||||
remote_send(x, 1) | |||||
recv_x = remote_recv(1) | |||||
y = recv_x * recv_x | |||||
grad([y], [as_tensor(np.ones_like(x_np))]) | |||||
with Grad() as grad: | |||||
x = as_tensor(x_np) | |||||
grad.wrt(x, callback=save_to(x)) | |||||
# need a placeholder to trace operator | |||||
remote_send(x, 1) | |||||
recv_x = remote_recv(1) | |||||
y = recv_x * recv_x | |||||
grad([y], [as_tensor(np.ones_like(x_np))]) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2) | np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2) | ||||
elif rank == 1: | elif rank == 1: | ||||
grad = Grad() | |||||
recv_x = remote_recv(0) | |||||
remote_send(recv_x, 0) | |||||
grad([], []) | |||||
with Grad() as grad: | |||||
recv_x = remote_recv(0) | |||||
remote_send(recv_x, 0) | |||||
grad([], []) | |||||
worker() | worker() | ||||
@@ -86,11 +80,11 @@ def test_grad(): | |||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = as_tensor(x_np) | x = as_tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
y = cos(x) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = cos(x) | |||||
grad(y, as_tensor(np.ones_like(x_np))) | |||||
grad(y, as_tensor(np.ones_like(x_np))) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np)) | np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np)) | ||||
@@ -98,12 +92,12 @@ def test_grad_2(): | |||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = as_tensor(x_np) | x = as_tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = mul(x, x) | |||||
y = mul(y, y) | |||||
grad(y, as_tensor(np.ones_like(x_np))) | |||||
y = mul(x, x) | |||||
y = mul(y, y) | |||||
grad(y, as_tensor(np.ones_like(x_np))) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | ||||
@@ -113,32 +107,31 @@ def test_2nd_grad(): | |||||
x = as_tensor(x_np) | x = as_tensor(x_np) | ||||
ones = as_tensor(np.ones_like(x_np)) | ones = as_tensor(np.ones_like(x_np)) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
grad._priority = -1 | |||||
grad2 = Grad().wrt(x, callback=save_to(x)) | |||||
grad2._priority = 0 | |||||
y = cos(x) | |||||
with Grad("grad2") as grad2: | |||||
with Grad("grad") as grad: | |||||
grad2.wrt(x, callback=save_to(x)) | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = cos(x) | |||||
grad(y, ones) | |||||
z = x.grad | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
grad(y, ones) | |||||
z = x.grad | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
x.grad = None | |||||
grad2(z, ones) | |||||
x.grad = None | |||||
grad2(z, ones) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||||
def test_grad_with_tensor_wrapper(): | def test_grad_with_tensor_wrapper(): | ||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = mul(x, x) | |||||
y = mul(y, y) | |||||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||||
y = mul(x, x) | |||||
y = mul(y, y) | |||||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | ||||
@@ -162,18 +155,21 @@ def test_release(): | |||||
@check | @check | ||||
def _(): | def _(): | ||||
g = Grad().wrt(x) | |||||
y = x * x | |||||
g(y, dy) | |||||
with Grad() as g: | |||||
g.wrt(x) | |||||
y = x * x | |||||
g(y, dy) | |||||
@check | @check | ||||
def _(): | def _(): | ||||
with Grad().wrt(x): | |||||
with Grad() as g: | |||||
g.wrt(x) | |||||
pass | pass | ||||
@check | @check | ||||
def _(): | def _(): | ||||
with Grad().wrt(x): | |||||
with Grad() as g: | |||||
g.wrt(x) | |||||
y = x * x | y = x * x | ||||
@@ -181,12 +177,12 @@ def test_grad_inplace(): | |||||
x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = mul(x, x) | |||||
y *= y | |||||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||||
y = mul(x, x) | |||||
y *= y | |||||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | ||||
@@ -196,11 +192,11 @@ def test_identity(): | |||||
dy_np = np.random.rand(*x.shape).astype("float32") | dy_np = np.random.rand(*x.shape).astype("float32") | ||||
dy = mge.Tensor(dy_np) | dy = mge.Tensor(dy_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
(y,) = apply(Identity(), x) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
(y,) = apply(Identity(), x) | |||||
grad(y, dy) | |||||
grad(y, dy) | |||||
np.testing.assert_array_equal(x.grad.numpy(), dy_np) | np.testing.assert_array_equal(x.grad.numpy(), dy_np) | ||||
@@ -220,15 +216,14 @@ def test_elemwise_add(): | |||||
refs["y"] = TensorWeakRef(y) | refs["y"] = TensorWeakRef(y) | ||||
return x + y | return x + y | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
z = f(x, y) | |||||
del y | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
z = f(x, y) | |||||
del y | |||||
for k, r in refs.items(): | |||||
assert r() is None | |||||
grad(z, dz) | |||||
for k, r in refs.items(): | |||||
assert r() is None | |||||
grad(z, dz) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5) | np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5) | ||||
@@ -245,13 +240,12 @@ def test_elemwise_relu(): | |||||
refs["x"] = TensorWeakRef(x) | refs["x"] = TensorWeakRef(x) | ||||
return relu(x) | return relu(x) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
z = f(x) | |||||
assert refs["x"]() is None | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
z = f(x) | |||||
assert refs["x"]() is None | |||||
grad(z, dz) | |||||
grad(z, dz) | |||||
np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0]) | np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0]) | ||||
@@ -269,21 +263,21 @@ def test_reshape(): | |||||
x_np = np.random.rand(2, 5).astype("float32") | x_np = np.random.rand(2, 5).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
refs = {} | |||||
def f(x): | |||||
x = x * 1 | |||||
y = x.reshape(5, 2) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
def f(x): | |||||
x = x * 1 | |||||
y = x.reshape(5, 2) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) | np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) | ||||
@@ -291,21 +285,21 @@ def test_subtensor(): | |||||
x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
def f(x): | |||||
x = x * 1 | |||||
y = x[1:-1, :2] | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
def f(x): | |||||
x = x * 1 | |||||
y = x[1:-1, :2] | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy() | np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy() | ||||
) | ) | ||||
@@ -315,21 +309,21 @@ def test_IndexingMultiAxisVec(): | |||||
x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
refs = {} | |||||
def f(x): | |||||
x = x * 1 | |||||
y = x[[0, 2], [0, 2]] | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
def f(x): | |||||
x = x * 1 | |||||
y = x[[0, 2], [0, 2]] | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() | np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() | ||||
) | ) | ||||
@@ -339,21 +333,21 @@ def test_AxisAddRemove(): | |||||
x_np = np.random.rand(1, 5).astype("float32") | x_np = np.random.rand(1, 5).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
def f(x): | |||||
x = x * 1 | |||||
y = F.squeeze(F.expand_dims(x, 2), 0) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
def f(x): | |||||
x = x * 1 | |||||
y = F.squeeze(F.expand_dims(x, 2), 0) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy() | np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy() | ||||
) | ) | ||||
@@ -363,10 +357,11 @@ def test_Broadcast(): | |||||
x_np = np.random.rand(3, 3, 1).astype("float32") | x_np = np.random.rand(3, 3, 1).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
y = F.broadcast_to(x, (3, 3, 10)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = F.broadcast_to(x, (3, 3, 10)) | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) | np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) | ||||
@@ -374,10 +369,11 @@ def test_interpolate_fastpath(): | |||||
x_np = np.random.rand(3, 3, 32, 32).astype("float32") | x_np = np.random.rand(3, 3, 32, 32).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | ||||
@@ -385,10 +381,11 @@ def test_Reduce_sum(): | |||||
x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
y = x.sum(axis=0) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = x.sum(axis=0) | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | ||||
@@ -396,10 +393,11 @@ def test_Reduce_mean(): | |||||
x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
y = x.mean(axis=0) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
y = x.mean(axis=0) | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy()) | np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy()) | ||||
@@ -407,21 +405,21 @@ def test_addAxis(): | |||||
x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
refs = {} | |||||
def f(x): | |||||
x = x * 1 | |||||
y = F.expand_dims(x, [2, 3]) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
def f(x): | |||||
x = x * 1 | |||||
y = F.expand_dims(x, [2, 3]) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | ||||
@@ -429,21 +427,21 @@ def test_removeAxis(): | |||||
x_np = np.random.rand(3, 3, 1, 1).astype("float32") | x_np = np.random.rand(3, 3, 1, 1).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
refs = {} | |||||
refs = {} | |||||
def f(x): | |||||
x = x * 1 | |||||
y = F.squeeze(x, [2, 3]) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
def f(x): | |||||
x = x * 1 | |||||
y = F.squeeze(x, [2, 3]) | |||||
refs["x"] = TensorWeakRef(x) | |||||
return y | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
y = f(x) | |||||
for _, r in refs.items(): | |||||
assert r() is None | |||||
grad(y, F.ones_like(y)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) | np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) | ||||
@@ -452,11 +450,14 @@ def test_dot(): | |||||
x = mge.Tensor(x) | x = mge.Tensor(x) | ||||
u = F.ones((2,)) | u = F.ones((2,)) | ||||
v = F.ones((2,)) | v = F.ones((2,)) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
def f(x): | |||||
return F.dot(u, F.matmul(x, v)) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
def f(x): | |||||
return F.dot(u, F.matmul(x, v)) | |||||
y = f(x) | |||||
grad(y, F.ones_like(y)) | |||||
y = f(x) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) | np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) |
@@ -267,25 +267,27 @@ def _gen_roi_inp(): | |||||
def test_roi_align(): | def test_roi_align(): | ||||
inp_feat, rois = _gen_roi_inp() | inp_feat, rois = _gen_roi_inp() | ||||
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||||
output_shape = (7, 7) | |||||
out_feat = F.vision.roi_align( | |||||
inp_feat, | |||||
rois, | |||||
output_shape=output_shape, | |||||
mode="average", | |||||
spatial_scale=1.0 / 4, | |||||
sample_points=2, | |||||
aligned=True, | |||||
) | |||||
assert make_shape_tuple(out_feat.shape) == ( | |||||
rois.shape[0], | |||||
inp_feat.shape[1], | |||||
*output_shape, | |||||
) | |||||
with Grad() as grad: | |||||
grad.wrt(inp_feat, callback=_save_to(inp_feat)) | |||||
output_shape = (7, 7) | |||||
out_feat = F.vision.roi_align( | |||||
inp_feat, | |||||
rois, | |||||
output_shape=output_shape, | |||||
mode="average", | |||||
spatial_scale=1.0 / 4, | |||||
sample_points=2, | |||||
aligned=True, | |||||
) | |||||
assert make_shape_tuple(out_feat.shape) == ( | |||||
rois.shape[0], | |||||
inp_feat.shape[1], | |||||
*output_shape, | |||||
) | |||||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | ||||
@@ -307,20 +309,23 @@ def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): | |||||
def test_correlation(): | def test_correlation(): | ||||
##test case 0 check the grad shape | ##test case 0 check the grad shape | ||||
data1, data2 = _gen_correlation() | data1, data2 = _gen_correlation() | ||||
grad = Grad().wrt(data1, callback=_save_to(data1)) | |||||
out_feat = F.vision.correlation( | |||||
data1, | |||||
data2, | |||||
kernel_size=5, | |||||
max_displacement=4, | |||||
stride1=2, | |||||
stride2=2, | |||||
pad_size=2, | |||||
is_multiply=True, | |||||
) | |||||
with Grad() as grad: | |||||
grad.wrt(data1, callback=_save_to(data1)) | |||||
out_feat = F.vision.correlation( | |||||
data1, | |||||
data2, | |||||
kernel_size=5, | |||||
max_displacement=4, | |||||
stride1=2, | |||||
stride2=2, | |||||
pad_size=2, | |||||
is_multiply=True, | |||||
) | |||||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||||
assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) | assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) | ||||
##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194 | ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194 | ||||
@@ -391,32 +396,36 @@ def test_correlation(): | |||||
def test_roi_pooling(): | def test_roi_pooling(): | ||||
inp_feat, rois = _gen_roi_inp() | inp_feat, rois = _gen_roi_inp() | ||||
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||||
output_shape = (7, 7) | |||||
out_feat = F.vision.roi_pooling( | |||||
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, | |||||
) | |||||
assert make_shape_tuple(out_feat.shape) == ( | |||||
rois.shape[0], | |||||
inp_feat.shape[1], | |||||
*output_shape, | |||||
) | |||||
with Grad() as grad: | |||||
grad.wrt(inp_feat, callback=_save_to(inp_feat)) | |||||
output_shape = (7, 7) | |||||
out_feat = F.vision.roi_pooling( | |||||
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, | |||||
) | |||||
assert make_shape_tuple(out_feat.shape) == ( | |||||
rois.shape[0], | |||||
inp_feat.shape[1], | |||||
*output_shape, | |||||
) | |||||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | ||||
def test_adaptive_avg_pool2d(): | def test_adaptive_avg_pool2d(): | ||||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | ||||
oshp = (2, 2) | oshp = (2, 2) | ||||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||||
outp = F.adaptive_avg_pool2d(inp, oshp,) | |||||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
np.testing.assert_equal( | |||||
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||||
) | |||||
with Grad() as grad: | |||||
grad.wrt(inp, callback=_save_to(inp)) | |||||
outp = F.adaptive_avg_pool2d(inp, oshp,) | |||||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
np.testing.assert_equal( | |||||
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||||
) | |||||
grad(outp, tensor(F.ones_like(outp))) | |||||
grad(outp, tensor(F.ones_like(outp))) | |||||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | ||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
inp.grad.numpy(), | inp.grad.numpy(), | ||||
@@ -439,14 +448,16 @@ def test_adaptive_avg_pool2d(): | |||||
def test_adaptive_max_pool2d(): | def test_adaptive_max_pool2d(): | ||||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | ||||
oshp = (2, 2) | oshp = (2, 2) | ||||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||||
outp = F.adaptive_max_pool2d(inp, oshp,) | |||||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
np.testing.assert_equal( | |||||
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||||
) | |||||
with Grad() as grad: | |||||
grad.wrt(inp, callback=_save_to(inp)) | |||||
outp = F.adaptive_max_pool2d(inp, oshp,) | |||||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
np.testing.assert_equal( | |||||
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||||
) | |||||
grad(outp, tensor(F.ones_like(outp))) | |||||
grad(outp, tensor(F.ones_like(outp))) | |||||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | ||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
inp.grad.numpy(), | inp.grad.numpy(), | ||||
@@ -351,7 +351,7 @@ def test_expand_dims_for_scalar(): | |||||
for axis in [1, -2, (1, 2), (-2, -3)]: | for axis in [1, -2, (1, 2), (-2, -3)]: | ||||
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) | np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) | ||||
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis) | |||||
np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis) | |||||
@pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
@@ -9,6 +9,7 @@ | |||||
import inspect | import inspect | ||||
import io | import io | ||||
import itertools | import itertools | ||||
import random | |||||
from tempfile import mkstemp | from tempfile import mkstemp | ||||
import numpy as np | import numpy as np | ||||
@@ -25,7 +26,7 @@ from megengine.core.ops import builtin as ops | |||||
from megengine.core.ops.builtin import Elemwise | from megengine.core.ops.builtin import Elemwise | ||||
from megengine.core.tensor.utils import isscalar | from megengine.core.tensor.utils import isscalar | ||||
from megengine.functional import exp, log | from megengine.functional import exp, log | ||||
from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace | |||||
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace | |||||
from megengine.module import Module | from megengine.module import Module | ||||
from megengine.random import normal, uniform | from megengine.random import normal, uniform | ||||
from megengine.utils.naming import AutoNaming | from megengine.utils.naming import AutoNaming | ||||
@@ -464,36 +465,92 @@ def test_trace_warp_perspective(): | |||||
f(x, M) | f(x, M) | ||||
def test_raise_on_trace(): | |||||
step_count = 0 | |||||
catch_count = 0 | |||||
bad_step = 10 | |||||
@pytest.mark.parametrize( | |||||
"normal_expr, mismatch_expr, reason", | |||||
[ | |||||
("a + b + c", "a + b - c", "operator mismatch"), | |||||
("a + b + 1", "a + b + 2", "tensors not equals"), | |||||
("((a + b), (b + c))[0]", "a + b", "mismature end"), | |||||
("a + b + c", "c + (a + b)", "expect internal node, got external"), | |||||
("c + (a + b)", "a + b + c", "expect external node, got internal"), | |||||
("a + b + c", "a + b + c + c", "too many instructions"), | |||||
("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"), | |||||
("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"), | |||||
], | |||||
) | |||||
def test_trace_mismatch(normal_expr, mismatch_expr, reason): | |||||
a = tensor([1, 2, 3, 4]) | |||||
b = tensor([5, 6, 7, 8]) | |||||
c = tensor([9, 0, 1, 2]) | |||||
mismatch = False | |||||
@trace(symbolic=True) | |||||
def fn(a, b, c): | |||||
if not mismatch: | |||||
result = eval(normal_expr) | |||||
else: | |||||
result = eval(mismatch_expr) | |||||
return result | |||||
for i in range(20): | |||||
try: | |||||
d = fn(a, b, c) | |||||
except TraceError as e: | |||||
assert mismatch | |||||
assert str(e) == "trace error because {}".format(reason) | |||||
except: | |||||
pytest.fail("unexpected trace error") | |||||
else: | |||||
assert not mismatch | |||||
np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy()) | |||||
mismatch = random.random() > 0.8 | |||||
class CatchMe(Exception): | |||||
pass | |||||
def test_exception_in_trace(): | |||||
a = tensor([1, 2, 3, 4]) | a = tensor([1, 2, 3, 4]) | ||||
b = tensor([5, 6, 7, 8]) | b = tensor([5, 6, 7, 8]) | ||||
c = tensor([9, 0, 1, 2]) | c = tensor([9, 0, 1, 2]) | ||||
@trace | |||||
def add_abc(a, b, c): | |||||
ps = a + b | |||||
result = ps + c | |||||
if step_count == bad_step: | |||||
raise CatchMe("catch me") | |||||
mismatch = False | |||||
exc = Exception() | |||||
@trace(symbolic=True) | |||||
def fn(a, b, c): | |||||
result = a + b | |||||
if not mismatch: | |||||
result += c | |||||
else: | |||||
raise exc | |||||
return result | return result | ||||
for i in range(100): | |||||
for i in range(20): | |||||
try: | try: | ||||
d = add_abc(a, b, c) | |||||
except CatchMe as e: | |||||
catch_count += 1 | |||||
d = fn(a, b, c) | |||||
except TraceError as e: | |||||
pytest.fail("unexpected trace error") | |||||
except Exception as e: | |||||
assert mismatch | |||||
assert e is exc | |||||
else: | else: | ||||
assert not mismatch | |||||
np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) | np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) | ||||
step_count += 1 | |||||
mismatch = random.random() > 0.8 | |||||
assert catch_count == 1 | |||||
def test_graph_error(): | |||||
a = tensor(np.arange(8).reshape((2, 4))) | |||||
b = tensor(np.arange(8).reshape((2, 4))) | |||||
@trace(symbolic=True) | |||||
def fn(a, b): | |||||
return a + b | |||||
fn(a, b) | |||||
with pytest.raises(RuntimeError): | |||||
fn(a, b.transpose()) | |||||
fn(a, b) | |||||
@pytest.mark.parametrize("trace_mode", [False, True]) | @pytest.mark.parametrize("trace_mode", [False, True]) | ||||
@@ -653,9 +710,10 @@ def test_trace_jit_config(): | |||||
x = tensor(2) | x = tensor(2) | ||||
y = func(x) | y = func(x) | ||||
func._compile() | |||||
y = func(x) | |||||
# func._compile() | |||||
options = func._graph.options | |||||
options = func._trace.options | |||||
mapping = {None: 0, False: 1, True: 2} | mapping = {None: 0, False: 1, True: 2} | ||||
assert options.graph_opt.jit == 0 | assert options.graph_opt.jit == 0 | ||||
assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle] | assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle] | ||||
@@ -82,9 +82,10 @@ def test_tqt(): | |||||
x = mge.tensor(x, dtype="float32") | x = mge.tensor(x, dtype="float32") | ||||
s = mge.tensor(s, dtype="float32") | s = mge.tensor(s, dtype="float32") | ||||
g_y = mge.tensor(g_y, dtype="float32") | g_y = mge.tensor(g_y, dtype="float32") | ||||
grad = Grad().wrt(x, s, callback=cb) | |||||
y = tqt_forward(-127, 127, x, s) | |||||
grad(y, g_y) | |||||
with Grad() as grad: | |||||
grad.wrt(x, s, callback=cb) | |||||
y = tqt_forward(-127, 127, x, s) | |||||
grad(y, g_y) | |||||
g_x, g_s = g | g_x, g_s = g | ||||
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) | np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) | ||||
@@ -131,14 +132,16 @@ def test_fakequant(): | |||||
# test backward | # test backward | ||||
x = tensor(inp_data, dtype=np.float32) | x = tensor(inp_data, dtype=np.float32) | ||||
grad = Grad().wrt(x, callback=_save_to(x)) | |||||
y = fake_quant_tensor(x, qparams) | |||||
grad(y, tensor(F.ones_like(x))) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=_save_to(x)) | |||||
y = fake_quant_tensor(x, qparams) | |||||
grad(y, tensor(F.ones_like(x))) | |||||
x1 = tensor(inp_data, dtype=np.float32) | x1 = tensor(inp_data, dtype=np.float32) | ||||
grad = Grad().wrt(x1, callback=_save_to(x1)) | |||||
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) | |||||
grad(y1, tensor(F.ones_like(x1))) | |||||
with Grad() as grad: | |||||
grad.wrt(x1, callback=_save_to(x1)) | |||||
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) | |||||
grad(y1, tensor(F.ones_like(x1))) | |||||
assert np.allclose(x.grad.numpy(), x1.grad.numpy()) | assert np.allclose(x.grad.numpy(), x1.grad.numpy()) | ||||
assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) | assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) | ||||
@@ -237,9 +240,10 @@ def test_lsq(): | |||||
grad_s = mge.tensor(grad_s, dtype="float32") | grad_s = mge.tensor(grad_s, dtype="float32") | ||||
g_y = mge.tensor(g_y, dtype="float32") | g_y = mge.tensor(g_y, dtype="float32") | ||||
grad = Grad().wrt(x, s, callback=cb) | |||||
y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||||
grad(y, g_y) | |||||
with Grad() as grad: | |||||
grad.wrt(x, s, callback=cb) | |||||
y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||||
grad(y, g_y) | |||||
g_x, g_s = g | g_x, g_s = g | ||||
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) | np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) | ||||
@@ -430,9 +430,10 @@ def test_ShuffleRNG(): | |||||
n, m = 6, 3 | n, m = 6, 3 | ||||
arr = np.arange(n * m) | arr = np.arange(n * m) | ||||
out0 = Tensor(arr, dtype="float32") | out0 = Tensor(arr, dtype="float32") | ||||
grad = Grad().wrt(out0, callback=cb) | |||||
random.shuffle(out0) | |||||
grad(out0, F.ones_like(out0)) | |||||
with Grad() as grad: | |||||
grad.wrt(out0, callback=cb) | |||||
random.shuffle(out0) | |||||
grad(out0, F.ones_like(out0)) | |||||
m1 = RNG(seed=111, device="xpu0") | m1 = RNG(seed=111, device="xpu0") | ||||
m2 = RNG(seed=111, device="xpu1") | m2 = RNG(seed=111, device="xpu1") | ||||
m3 = RNG(seed=222, device="xpu0") | m3 = RNG(seed=222, device="xpu0") | ||||