GitOrigin-RevId: 32dd49a23a
tags/v1.8.0
@@ -28,9 +28,6 @@ class AttachSpec: | |||
__slots__ = "tensor", "callbacks" | |||
_global_priority = 0 | |||
class GradManager: | |||
r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode | |||
automatic differentiation (a.k.a. back propagation). | |||
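A minimal usage sketch against the public ``megengine`` API (``model``, ``data`` and ``loss_fn`` are placeholders):

.. code-block:: python

    import megengine.optimizer as optim
    from megengine.autodiff import GradManager

    gm = GradManager()
    gm.attach(model.parameters())            # mark tensors to differentiate w.r.t.
    opt = optim.SGD(model.parameters(), lr=0.01)

    with gm:                                 # __enter__ -> record()
        loss = loss_fn(model(data))          # forward ops are recorded
        gm.backward(loss)                    # populates .grad on attached tensors
    opt.step()
    opt.clear_grad()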
@@ -127,7 +124,6 @@ class GradManager: | |||
self._grad = None | |||
self._after_backward_callback = [] | |||
self._gradients = {} | |||
self._priority = None | |||
def attached_tensors(self): | |||
r"""Return attached tensor list from :meth:`attach`.""" | |||
@@ -299,31 +295,25 @@ class GradManager: | |||
tensor.grad = grad | |||
else: | |||
tensor.grad += grad | |||
if tensor._isscalar() and tensor.grad is not None: | |||
tensor.grad._setscalar() | |||
finally: | |||
self.release() | |||
backwarding_grad_manager = cache | |||
set_option("record_computing_path", 1) | |||
pop_scope("backward") | |||
set_option("record_computing_path", 1) | |||
pop_scope("backward") | |||
def record(self): | |||
r"""Start recording operations | |||
After this call, you will be able to call :meth:`backward`. | |||
""" | |||
global _global_priority | |||
if self._recording: | |||
raise RuntimeError("already recording") | |||
grad = Grad() | |||
self._recording = True | |||
self._grad = grad | |||
grad.__enter__() | |||
for spec in self._attach_specs.values(): | |||
self._do_record(spec) | |||
if self._priority is None: | |||
grad._priority = _global_priority | |||
_global_priority -= 1 | |||
grad.__enter__() | |||
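For reference, the explicit lifecycle that the context-manager protocol wraps, as a sketch (``w`` is a placeholder tensor; the ``finally: self.release()`` in ``backward`` above is what ends the recording):

.. code-block:: python

    gm = GradManager()
    gm.attach(w)
    gm.record()              # after this, backward() may be called
    loss = (w * w).sum()
    gm.backward(loss)        # calls release() in its finally block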
def _do_record(self, spec): | |||
tensor = spec.tensor() | |||
@@ -331,6 +321,8 @@ class GradManager: | |||
return | |||
def callback(grad, callbacks=spec.callbacks): | |||
from ..functional import ones_like | |||
for cb in callbacks: | |||
grad = cb(tensor, grad) | |||
self._gradients[id(tensor)] = grad | |||
@@ -343,14 +335,11 @@ class GradManager: | |||
After this call, you will not be able to call :meth:`backward`. | |||
""" | |||
global _global_priority | |||
if self._grad is not None: | |||
self._grad.__exit__(None, None, None) | |||
self._grad = None | |||
self._recording = False | |||
self._gradients = dict() | |||
if self._priority is None: | |||
_global_priority += 1 | |||
def __enter__(self): | |||
self.record() | |||
@@ -382,15 +371,14 @@ class GradManagerGroup: | |||
__ror__ = merge_with | |||
def __enter__(self): | |||
global _global_priority | |||
_global_priority += 1 | |||
Grad.stack.append([]) | |||
Grad.begin_group() | |||
for gm in self._gms: | |||
gm._priority = _global_priority | |||
gm.record() | |||
assert gm._grad is not None | |||
Grad.end_group() | |||
def __exit__(self, exc_type, exc_val, exc_tb): | |||
global _global_priority | |||
_global_priority -= 1 | |||
for gm in self._gms: | |||
for gm in reversed(self._gms): | |||
gm.release() | |||
gm._priority = None | |||
assert gm._grad is None |
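The group logic above enters all managers as one priority group (``Grad.begin_group`` / ``Grad.end_group``) and releases them in reverse order on exit. A hedged sketch, assuming ``GradManager.__or__`` builds a ``GradManagerGroup`` as the ``__ror__ = merge_with`` alias suggests (``w1``/``w2`` are placeholders):

.. code-block:: python

    gm1, gm2 = GradManager(), GradManager()
    gm1.attach(w1)
    gm2.attach(w2)

    with gm1 | gm2:              # both managers record as one group
        loss = (w1 * w2).sum()
        gm1.backward(loss)       # gm2's Grad is suppressed while this runs,
                                 # then resumed (see Grad.backward)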
@@ -6,17 +6,9 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import heapq | |||
import itertools | |||
import typing | |||
import weakref | |||
import numpy as np | |||
from .._imperative_rt import core2, ops | |||
from ..ops.builtin import Elemwise, OpDef, RemoteSend | |||
from ..ops.special import Const | |||
from .._imperative_rt import core2 | |||
_grad_count = 0 | |||
_grad_manager_dict = weakref.WeakValueDictionary() | |||
@@ -36,6 +28,10 @@ class GradKey(core2.GradKey): | |||
class Grad: | |||
stack = [] | |||
grouping = False | |||
key2grad = weakref.WeakValueDictionary() | |||
def __init__(self, name=None): | |||
global _grad_count | |||
if name is None: | |||
@@ -43,15 +39,9 @@ class Grad: | |||
_grad_count += 1 | |||
self._refkeeper = [] | |||
self._impl = GradKey(name) | |||
Grad.key2grad[self._impl] = self | |||
_grad_manager_dict[self._name] = self | |||
@property | |||
def _priority(self): | |||
return self._impl.priority | |||
@_priority.setter | |||
def _priority(self, priority): | |||
self._impl.priority = priority | |||
self._group = [weakref.ref(self)] | |||
@property | |||
def _name(self): | |||
@@ -70,33 +60,80 @@ class Grad: | |||
if not isinstance(ys, Sequence): | |||
ys = [ys] | |||
if not isinstance(dys, Sequence): | |||
dys = [dys] | |||
group = [ref() for ref in self._group] | |||
for grad in group: | |||
if grad is self: | |||
continue | |||
grad.suppress() | |||
self._impl.backward(ys, dys) | |||
for grad in group: | |||
if grad is self: | |||
continue | |||
grad.resume() | |||
self._refkeeper = None | |||
return None | |||
def __enter__(self): | |||
ref = weakref.ref(self) | |||
self._impl.enter() | |||
if Grad.grouping: | |||
group = Grad.stack[-1] | |||
self._group = group | |||
group.append(ref) | |||
else: | |||
Grad.stack.append(self._group) | |||
return self | |||
def __exit__(self, _1, _2, _3): | |||
self._impl.exit() | |||
self._refkeeper = None | |||
del self._impl | |||
class Function(ops.PyOpBase): | |||
del Grad.key2grad[self._impl] | |||
self._impl = None | |||
self._group.remove(weakref.ref(self)) | |||
if len(self._group) == 0: | |||
Grad.stack.remove(self._group) | |||
@staticmethod | |||
def begin_group(): | |||
assert not Grad.grouping | |||
Grad.grouping = True | |||
@staticmethod | |||
def end_group(): | |||
group = Grad.stack[-1] | |||
assert len(group) > 0 | |||
assert Grad.grouping | |||
Grad.grouping = False | |||
def suppress(self): | |||
if self._impl is not None: | |||
self._impl.suppress() | |||
def resume(self): | |||
if self._impl is not None: | |||
self._impl.resume() | |||
class Function: | |||
r"""Defines a block of operations with customizable differentiation. | |||
The computation should be defined in the ``forward`` method, with its gradient | |||
computation defined in the ``backward`` method. | |||
Each instance of ``Function`` should be used only once during the forward pass. | |||
Examples: | |||
.. code-block:: | |||
class Sigmoid(Function): | |||
def forward(self, x): | |||
y = 1 / (1 + F.exp(-x)) | |||
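The hunk truncates the docstring example here; a complete sketch in the same spirit (assuming ``F`` is ``megengine.functional`` and ``Function`` is importable from ``megengine.autodiff``) pairs ``forward`` with the matching ``backward``:

.. code-block:: python

    import megengine.functional as F
    from megengine.autodiff import Function

    class Sigmoid(Function):
        def forward(self, x):
            y = 1 / (1 + F.exp(-x))
            self.y = y                   # keep the activation for backward
            return y

        def backward(self, dy):
            y = self.y
            return dy * y * (1 - y)      # d(sigmoid)/dx = y * (1 - y)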
@@ -115,7 +152,7 @@ class Function(ops.PyOpBase): | |||
Returns: | |||
a tuple of Tensor or a single Tensor. | |||
Note: | |||
* This method should return a tuple of Tensor or a single Tensor representing the output | |||
of the function. | |||
@@ -128,7 +165,7 @@ class Function(ops.PyOpBase): | |||
Args: | |||
output_grads: gradients of outputs that are returned by :meth:`forward`. | |||
Note: | |||
* In cases where some output tensors are not related to the loss function, the corresponding | |||
values in ``output_grads`` will be ``None``. | |||
@@ -148,10 +185,40 @@ class Function(ops.PyOpBase): | |||
return self._default_rule(*args), self.backward | |||
def __call__(self, *args): | |||
ret = core2.apply(self, *args) | |||
for arg in args: | |||
if not isinstance(arg, core2.Tensor): | |||
raise TypeError( | |||
"op Function expect type Tensor as inputs, got {}".format(type(arg)) | |||
) | |||
grad_key = core2.get_grad_key(args) | |||
if grad_key is None: | |||
return self._default_rule(*args) | |||
grad = Grad.key2grad[grad_key] | |||
group = [ref() for ref in grad._group] | |||
for grad in group: | |||
grad.suppress() | |||
outputs, backward = self._grad_rule(*args) | |||
for grad in reversed(group): | |||
grad.resume() | |||
def normalized_backward(*output_grads): | |||
input_grads = backward(*output_grads) | |||
if isinstance(input_grads, core2.Tensor) or input_grads is None: | |||
input_grads = (input_grads,) | |||
return input_grads | |||
if self.__single_output: | |||
(ret,) = ret | |||
return ret | |||
outputs = (outputs,) | |||
for grad in reversed(group): | |||
if grad._impl is None: | |||
continue | |||
outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs) | |||
if self.__single_output: | |||
(outputs,) = outputs | |||
return outputs | |||
def __getstate__(self): | |||
return self.__dict__ | |||
@@ -26,7 +26,6 @@ from .utils import ( | |||
convert_inputs, | |||
isscalar, | |||
make_shape_tuple, | |||
setscalar, | |||
) | |||
_ElwMod = builtin.Elemwise.Mode | |||
@@ -34,14 +33,7 @@ _ElwMod = builtin.Elemwise.Mode | |||
def _elwise_apply(args, mode): | |||
op = builtin.Elemwise(mode) | |||
_isscalar = True | |||
for i in args: | |||
if isscalar(i) == False: | |||
_isscalar = False | |||
break | |||
(result,) = apply(op, *args) | |||
if _isscalar: | |||
setscalar(result) | |||
return result | |||
@@ -203,8 +195,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
op = builtin.RemoveAxis(axis=axis) | |||
(result,) = apply(op, inp) | |||
if len(axis) == inp.ndim: | |||
setscalar(result) | |||
return result | |||
@@ -221,6 +211,7 @@ def _reduce(mode): | |||
op = builtin.Reduce(mode=mode, axis=0) | |||
(result,) = apply(op, data) | |||
result = _remove_axis(result, 0) | |||
elif isinstance(axis, collections.abc.Iterable): | |||
axis = _normalize_axis(self.ndim, axis, reverse=True) | |||
for ai in axis: | |||
@@ -239,8 +230,6 @@ def _reduce(mode): | |||
if self.dtype == np.bool_: | |||
if mode in ["min", "max"]: | |||
result = result.astype("bool") | |||
if axis is None or self.ndim == 1: | |||
setscalar(result) | |||
return result | |||
return f | |||
@@ -457,7 +446,6 @@ class ArrayMethodMixin(abc.ABC): | |||
len(args) == 0 | |||
), "transpose for scalar does not accept additional args" | |||
ret = self.to(self.device) | |||
setscalar(ret) | |||
return ret | |||
if not args: | |||
args = range(self.ndim)[::-1] | |||
@@ -111,7 +111,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
if not isinstance(tuple_val, tuple): | |||
tuple_val = (tuple_val,) | |||
ndim_indexed = 0 | |||
ndim_indexed_scalar = 0 | |||
for i in tuple_val: | |||
if not i is Ellipsis: | |||
ndim_indexed += ( | |||
@@ -119,14 +118,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") | |||
else 1 | |||
) | |||
if isscalar(i): | |||
ndim_indexed_scalar += 1 | |||
ret_scalar = False | |||
try: | |||
ret_scalar = ndim_indexed_scalar == inp.ndim | |||
except ValueError: | |||
# inp.ndim is unknown | |||
pass | |||
else: | |||
if ndim_indexed > inp.ndim: | |||
raise IndexError( | |||
@@ -221,7 +212,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
items.append(item) | |||
if new_axes: | |||
raise IndexError("newaxis is not allowed here") | |||
return inp, tensors, items, use_subtensor, ret_scalar | |||
return inp, tensors, items, use_subtensor | |||
def try_condtake(tensor, index): | |||
@@ -247,14 +238,12 @@ def getitem(tensor, index): | |||
try_result = try_condtake(tensor, index) | |||
if len(try_result) == 2: | |||
return try_result[0] | |||
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | |||
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
else: | |||
op = builtin.IndexingMultiAxisVec(items=items) | |||
(result,) = apply(op, tensor, *tensors) | |||
if ret_scalar: | |||
result._setscalar() | |||
return result | |||
@@ -266,7 +255,7 @@ def setitem(tensor, index, value): | |||
tensor = tensor.reshape(-1) | |||
if not isinstance(value, (Tensor, SymbolVar)): | |||
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | |||
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
else: | |||
@@ -17,6 +17,7 @@ import numpy as np | |||
from .. import _imperative_rt | |||
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat | |||
from .._imperative_rt.core2 import apply | |||
from .._wrap import as_device | |||
from ..ops.builtin import OpDef | |||
@@ -126,9 +127,8 @@ class Graph(_imperative_rt.ComputingGraph): | |||
class VarNode: | |||
def __init__(self, node: _imperative_rt.VarNode, isscalar=False): | |||
def __init__(self, node: _imperative_rt.VarNode): | |||
self._node = node | |||
self._isscalar = isscalar | |||
if hasattr(self.graph, "_var_cache"): | |||
self.graph._var_cache[node] = self | |||
@@ -530,9 +530,6 @@ def _unwrap(x): | |||
def apply_normal_varnode(op: OpDef, *args: VarNode): | |||
# for PyOp like RemoteSend/Recv | |||
if getattr(op, "op", None): | |||
op = op.op | |||
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||
return _wrap(outputs) | |||
@@ -51,10 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||
def astype(x, dtype): | |||
dtype = np.dtype(dtype) | |||
if not is_dtype_equal(x.dtype, dtype): | |||
isscalar = x._isscalar() | |||
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) | |||
if isscalar: | |||
x._setscalar() | |||
return x | |||
@@ -129,13 +126,6 @@ def isscalar(x): | |||
return np.isscalar(x) | |||
def setscalar(x): | |||
if isinstance(x, (Tensor, SymbolVar)): | |||
x._setscalar() | |||
else: | |||
raise NotImplementedError("Unsupport type {}".format(type(x))) | |||
def astensor1d(x, *reference, dtype=None, device=None): | |||
"""Convert something to 1D tensor. Support following types | |||
@@ -237,6 +227,7 @@ for name, mode in [ | |||
("**", "pow"), | |||
("max", "max"), | |||
("additive", "add"), | |||
("exp", "EXP"), | |||
]: | |||
_opr_map[(name, 2)] = builtin.Elemwise(mode=mode) | |||
@@ -13,7 +13,7 @@ import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core.autodiff.grad import Function, _grad_manager_dict | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..core.tensor.utils import isscalar | |||
from ..device import get_default_device, what_is_xpu | |||
from ..tensor import Tensor | |||
from . import group | |||
@@ -72,15 +72,6 @@ def collective_comm(inp, mode, group, device): | |||
) | |||
(result,) = apply(op, inp) | |||
# assume all workers have homogeneous shape | |||
if mode in ( | |||
CollectiveComm.Mode.REDUCE_SUM, | |||
CollectiveComm.Mode.BROADCAST, | |||
CollectiveComm.Mode.ALL_REDUCE_SUM, | |||
CollectiveComm.Mode.ALL_REDUCE_MAX, | |||
CollectiveComm.Mode.ALL_REDUCE_MIN, | |||
): | |||
if isscalar(inp): | |||
setscalar(result) | |||
return result | |||
@@ -190,8 +181,7 @@ def reduce_sum( | |||
# Rank 0 # output: None | |||
# Rank 1 # output: Tensor([1]) | |||
""" | |||
op = _ReduceSum(group, device) | |||
(out,) = apply(op, inp) | |||
out = _ReduceSum(group, device)(inp) | |||
if group.rank == 0: | |||
return out | |||
@@ -258,8 +248,7 @@ def broadcast( | |||
_bcast_tracer_state(group, inp) | |||
op = _Broadcast(group, device) | |||
(out,) = apply(op, inp) | |||
out = _Broadcast(group, device)(inp) | |||
return out | |||
@@ -604,8 +593,7 @@ def gather( | |||
inp.shape | |||
) | |||
op = _Gather(group, device) | |||
(out,) = apply(op, inp) | |||
out = _Gather(group, device)(inp) | |||
if group.rank == 0: | |||
if axis == 0: | |||
@@ -708,8 +696,7 @@ def scatter( | |||
+ [_ for _ in range(axis + 1, inp.ndim + 1)] | |||
) | |||
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape) | |||
op = _Scatter(group, device) | |||
(out,) = apply(op, inp) | |||
out = _Scatter(group, device)(inp) | |||
return out | |||
@@ -832,7 +819,7 @@ class _RemoteRecv(Function): | |||
self.op = op | |||
def forward(self, dummy): | |||
return apply(self.op, dummy) | |||
return apply(self.op, dummy)[0] | |||
def backward(self, grad): | |||
get_client().bcast_val(grad is not None, self.op.key, 2) | |||
@@ -871,7 +858,7 @@ def remote_send(inp: Tensor, dest_rank: int): | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_to = dest_rank | |||
op.backend = _backend() | |||
(out,) = apply(_RemoteSend(op), inp) | |||
out = _RemoteSend(op)(inp) | |||
_save_output_for_autodiff(inp, out) | |||
@@ -912,11 +899,6 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||
inp = Tensor(0, device=device) | |||
_bcast_tracer_state(group, inp) | |||
_isscalar = False | |||
if len(shape) == 0: | |||
shape = (1,) | |||
_isscalar = True | |||
op = RemoteRecv() | |||
op.key = group.key | |||
op.cn = device | |||
@@ -926,7 +908,5 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||
op.rank_from = src_rank | |||
op.backend = _backend() | |||
(ret,) = apply(_RemoteRecv(op), inp) | |||
if _isscalar: | |||
setscalar(ret) | |||
ret = _RemoteRecv(op)(inp) | |||
return ret |
@@ -67,9 +67,6 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): | |||
op.offsets = offsets | |||
op.shapes = [s or (1,) for s in shapes] | |||
outputs = apply(op, inp) | |||
for s, x in zip(shapes, outputs): | |||
if not s: | |||
x._setscalar() | |||
return outputs | |||
@@ -1,25 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core._imperative_rt.core2 import ( | |||
set_allow_higher_order_directive as _set_allow_higher_order_directive, | |||
) | |||
__all__ = [ | |||
"enable_higher_order_directive", | |||
"disable_higher_order_directive", | |||
] | |||
def enable_higher_order_directive(): | |||
_set_allow_higher_order_directive(True) | |||
def disable_higher_order_directive(): | |||
_set_allow_higher_order_directive(False) |
@@ -12,8 +12,5 @@ from ..core.ops.builtin import InplaceAdd | |||
def _inplace_add_(dest, delta, alpha, beta): | |||
isscalar = dest._isscalar() | |||
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) | |||
if isscalar: | |||
dest._setscalar() | |||
return dest |
@@ -19,7 +19,7 @@ from ..core.ops import builtin | |||
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt | |||
from ..core.ops.special import Const | |||
from ..core.tensor import amp | |||
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph | |||
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph | |||
from ..jit import exclude_from_trace | |||
from ..tensor import Tensor | |||
from ..utils.deprecation import deprecated_kwargs_default | |||
@@ -1149,7 +1149,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: | |||
inp1.ndim <= 1 and inp2.ndim <= 1 | |||
), "Input tensors for dot must be 1-dimensional or scalar" | |||
(result,) = apply(op, inp1, inp2) | |||
setscalar(result) | |||
return result | |||
@@ -1200,5 +1199,4 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: | |||
for i in range(len(inps)): | |||
inps[i]._reset(oups[i]) | |||
out._setscalar() | |||
return out |
@@ -35,7 +35,6 @@ from ..core.tensor.utils import ( | |||
cast_tensors, | |||
convert_single_value, | |||
make_shape_tuple, | |||
setscalar, | |||
subgraph, | |||
) | |||
from ..device import get_default_device | |||
@@ -972,13 +972,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
) | |||
axis = sorted(axis) | |||
assert axis, "axis could not be empty" | |||
if inp._isscalar(): | |||
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) | |||
if len(axis) == 1: | |||
inp = copy(inp, device=None) | |||
inp._unsetscalar() | |||
return inp | |||
axis = axis[1:] | |||
op = builtin.AddAxis(axis=axis) | |||
(result,) = apply(op, inp) | |||
return result | |||
@@ -1164,8 +1157,6 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): | |||
if axis is None: | |||
inp = inp.reshape(-1) # flatten | |||
axis = 0 | |||
if inp._isscalar(): | |||
inp._unsetscalar() | |||
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device) | |||
# assume inp.ndim is not changed during trace | |||
max_axis = len(shape) - 1 | |||
@@ -6,19 +6,7 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core._imperative_rt.core2 import ( | |||
set_cpp_apply_const_with_tracing, | |||
set_cpp_apply_with_tracing, | |||
) | |||
from .dtr_config import DTRConfig | |||
from .graph_opt_config import GraphOptimizationConfig | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
from .tracing import ( | |||
apply_const_with_tracing, | |||
apply_with_tracing, | |||
exclude_from_trace, | |||
trace, | |||
) | |||
set_cpp_apply_with_tracing(apply_with_tracing) | |||
set_cpp_apply_const_with_tracing(apply_const_with_tracing) | |||
from .tracing import TraceError, exclude_from_trace, trace |
@@ -111,6 +111,7 @@ class Module(metaclass=ABCMeta): | |||
# used for profiler and automatic naming | |||
self._name = None | |||
self._short_name = None | |||
@abstractmethod | |||
def forward(self, inputs): | |||
@@ -137,7 +138,7 @@ class Module(metaclass=ABCMeta): | |||
return HookHandler(self._forward_hooks, hook) | |||
def __call__(self, *inputs, **kwargs): | |||
AutoNaming.push_scope(self.name if self.name is not None else self._name) | |||
AutoNaming.push_scope(self.name if self.name is not None else self._short_name) | |||
for hook in self._forward_pre_hooks.values(): | |||
modified_inputs = hook(self, inputs) | |||
if modified_inputs is not None: | |||
@@ -641,15 +642,43 @@ class Module(metaclass=ABCMeta): | |||
else: | |||
if modules is not None and name in modules: | |||
modules.remove(name) | |||
for k, v in _expand_structure(name, value): | |||
if not v._name: | |||
v._name = k | |||
elif v._name != k: | |||
def append_name(prefix, name): | |||
if prefix is None or prefix == "": | |||
return name | |||
return prefix + "." + name | |||
def set_name(parent, prefix, name, obj): | |||
if isinstance(obj, Tensor): | |||
assert obj.name is not None | |||
if obj.name != "": | |||
name = obj.name | |||
full_name = append_name(prefix, name) | |||
if obj._short_name and obj._short_name != name: | |||
logger.warning( | |||
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( | |||
type(v), type(self), k, v._name | |||
obj._short_name, type(parent), name, obj._short_name | |||
) | |||
) | |||
return | |||
if isinstance(obj, Tensor): | |||
obj._prefix = prefix | |||
obj._name = full_name | |||
obj._short_name = name | |||
obj._set_name(obj._name) | |||
return obj._name | |||
elif isinstance(obj, Module): | |||
obj._name = full_name | |||
obj._short_name = name | |||
for k, v in obj._flatten(recursive=False, with_key=True): | |||
set_name(obj, full_name, k, v) | |||
return obj._name | |||
else: | |||
assert False | |||
for k, v in _expand_structure(name, value): | |||
prefix = self._name if self._name else self.name | |||
set_name(self, prefix, k, v) | |||
super().__setattr__(name, value) | |||
def __delattr__(self, name: str): | |||
@@ -14,6 +14,7 @@ from numpy.random import MT19937 | |||
from .. import Tensor | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._imperative_rt.core2 import sync as _sync | |||
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle | |||
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | |||
from ..core._imperative_rt.ops import ( | |||
@@ -650,6 +651,10 @@ class RNG: | |||
def __del__(self): | |||
if self._handle != 0: | |||
# An RNG op might execute after its handle is released due to async | |||
# dispatch, so we must sync before deleting a handle to avoid a | |||
# memory leak or use-after-free. | |||
_sync() | |||
_delete_rng_handle(self._handle) | |||
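To make the hazard concrete, a sketch of the failure mode this sync guards against (``RNG`` usage is illustrative):

.. code-block:: python

    from megengine.random import RNG

    rng = RNG(seed=0)
    x = rng.uniform(size=(4,))   # dispatched asynchronously; may still be queued
    del rng                      # without the sync above, the handle could be
                                 # freed while the queued op still references it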
@@ -12,7 +12,7 @@ import numpy as np | |||
from .core._imperative_rt import CompNode | |||
from .core._imperative_rt.core2 import Tensor as _Tensor | |||
from .core._imperative_rt.core2 import apply | |||
from .core._imperative_rt.core2 import apply, set_py_tensor_type | |||
from .core._trace_option import use_symbolic_shape | |||
from .core._wrap import as_device | |||
from .core.ops.builtin import Copy, GetVarShape | |||
@@ -20,7 +20,6 @@ from .core.tensor.array_method import ArrayMethodMixin | |||
from .device import _valid_device, get_default_device | |||
from .logger import get_logger | |||
from .utils.deprecation import deprecated | |||
from .utils.naming import AutoNaming | |||
logger = get_logger(__name__) | |||
@@ -40,6 +39,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
grad = None | |||
dmap_callback = None | |||
_qparams = None | |||
_custom_name = "" | |||
_name = None | |||
_short_name = None | |||
_prefix = None | |||
def __new__( | |||
cls, | |||
@@ -81,9 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
device: str = None, | |||
is_const: bool = False, | |||
no_cache: bool = False, | |||
name: str = None, | |||
name: str = "", | |||
): | |||
pass | |||
if name is None: | |||
name = "" | |||
self._custom_name = name | |||
self._name = name | |||
self._short_name = name | |||
self._set_name(self._name) | |||
self._prefix = None | |||
@property | |||
def shape(self) -> Union[tuple, "Tensor"]: | |||
@@ -151,12 +160,13 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
@property | |||
def name(self): | |||
return self.c_name | |||
return self._custom_name | |||
@name.setter | |||
def name(self, name): | |||
self.c_name = name | |||
AutoNaming.record_var_name(self._mixin_handle, name) | |||
self._custom_name = name | |||
self._name = self._prefix + "." + name if self._prefix else name | |||
self._set_name(self._name) | |||
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | |||
def set_value(self, value): | |||
@@ -224,6 +234,9 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
self._qparams = qparams | |||
set_py_tensor_type(Tensor) | |||
tensor = Tensor | |||
@@ -6,7 +6,6 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
from . import compat | |||
from ._passes import optimize | |||
from .pytree import register_supported_type | |||
@@ -14,14 +13,12 @@ from .tm_config import disable_default_checker, enable_expr_checker | |||
from .traced_module import ( | |||
TracedModule, | |||
_register_all_builtin_module, | |||
cpp_apply_module_trace, | |||
register_as_builtin, | |||
trace_module, | |||
wrap, | |||
) | |||
_register_all_builtin_module() | |||
set_cpp_apply_module_trace(cpp_apply_module_trace) | |||
__all__ = [ | |||
"register_as_builtin", | |||
@@ -13,7 +13,6 @@ import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._imperative_rt.ops import ROIAlign, ROIPooling | |||
from ..core.ops.builtin import Copy | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..tensor import Tensor | |||
from .tm_config import _exclude_from_trace | |||
@@ -70,8 +69,6 @@ class TracedModuleChecker: | |||
self.current_node2values()[node] = apply( | |||
Copy(comp_node=value.device), value | |||
)[0] | |||
if isscalar(value): | |||
setscalar(self.current_node2values()[node]) | |||
def check_apply_special_cases(self, opdef, num_outputs): | |||
indexs = list(range(num_outputs)) | |||
@@ -20,6 +20,7 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..core._imperative_rt.core2 import ( | |||
apply, | |||
is_tracing_module, | |||
set_module_trace_hook, | |||
set_module_tracing, | |||
unset_module_tracing, | |||
) | |||
@@ -605,8 +606,7 @@ class Apply(Expr): | |||
def apply_module_trace_hook(cls, opdef, *inputs): | |||
for i in inputs: | |||
node = NodeMixin.get(i, None) | |||
if node is None: # capture as constant | |||
NodeMixin.wrap_safe(i, Constant.make(i)) | |||
assert node is not None | |||
if isinstance(opdef, FakeQuant): | |||
inp_nodes = [NodeMixin.get(inputs[0])] | |||
@@ -805,3 +805,12 @@ class Constant(Expr): | |||
if isinstance(v, _ModuleState): | |||
state[k] = v.to_module() | |||
self.__dict__.update(state) | |||
def _module_trace_capture(value): | |||
node = Constant.make(value) | |||
NodeMixin.wrap_safe(value, node) | |||
return node | |||
set_module_trace_hook(Apply.apply_module_trace_hook) |
@@ -101,9 +101,7 @@ BUILTIN_TENSOR_WRAP_METHOD = [ | |||
"requires_grad", | |||
"_reset", | |||
"_isscalar", | |||
"_setscalar", | |||
"_tuple_shape", | |||
"_unsetscalar", | |||
] | |||
@@ -43,7 +43,6 @@ from ..core._imperative_rt.core2 import ( | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..core.ops.builtin import Copy | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..module import Module | |||
from ..module import external as MExternal | |||
from ..module.qat import QATModule | |||
@@ -1295,12 +1294,9 @@ def _wrapped_function(orig_func): | |||
return orig_func(*args, **kwargs) | |||
if isinstance(args[1], RawTensor): | |||
node = NodeMixin.get(inputs[1]) | |||
is_scalar = isscalar(inputs[1]) | |||
inputs[1] = apply( | |||
Copy(comp_node=inputs[1].device), Tensor(inputs[1]) | |||
)[0] | |||
if is_scalar: | |||
setscalar(inputs[1]) | |||
# copy inputs[1] to avoid tensor and Tensor(tensor) sharing the same m_tensor, | |||
# which would cause them to have the same _NodeMixin__node during tracing. | |||
NodeMixin.wrap_safe(inputs[1], node) | |||
@@ -2468,8 +2464,8 @@ def trace_module( | |||
try: | |||
net_name = mod._name if mod._name else mod.__class__.__name__ | |||
use_sym_shape = set_symbolic_shape(True) | |||
set_module_tracing() | |||
set_active_module_tracer(module_tracer(_wrapped_function)) | |||
set_module_tracing() | |||
for cls in [Expr, Node]: | |||
cls._set_next_id(0) | |||
with active_module_tracer().patcher: | |||
@@ -2518,9 +2514,9 @@ def trace_module( | |||
return traced_mod | |||
finally: | |||
set_symbolic_shape(use_sym_shape) | |||
set_active_module_tracer(None) | |||
unset_module_tracing() | |||
for t in mod.tensors(recursive=True): | |||
NodeMixin.clear_node(t) | |||
for t in inputs: | |||
NodeMixin.clear_node(t) | |||
set_active_module_tracer(None) |
@@ -137,6 +137,11 @@ class Profiler(ContextDecorator): | |||
get_logger().info("process {} generating {}".format(self._pid, format)) | |||
self._dump_callback(path, format) | |||
get_logger().info("profiling results written to {}".format(path)) | |||
if os.path.getsize(path) > 64 * 1024 * 1024: | |||
get_logger().warning( | |||
"profiling results too large, maybe you are profiling multi iters," | |||
"consider attach profiler in each iter separately" | |||
) | |||
self._dump_callback = None | |||
_living_profilers.remove(self) | |||
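What the warning recommends, sketched (``loader``/``train_step`` are placeholders; the ``path`` argument mirrors the dump logic above):

.. code-block:: python

    from megengine.utils.profiler import Profiler

    for step, batch in enumerate(loader):
        if step == 100:                       # profile a single iteration
            with Profiler(path="./profile"):  # one dump is written on exit
                train_step(batch)
        else:
            train_step(batch)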
@@ -9,9 +9,8 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||
#include "./grad.h" | |||
#include "megbrain/imperative/backward_graph_opt.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
@@ -19,465 +18,19 @@ | |||
#include "range/v3/all.hpp" | |||
#include "./transformation.h" | |||
namespace py = pybind11; | |||
namespace views = ranges::views; | |||
namespace mgb::imperative::python { | |||
using scoped_disable = ApplyContext::scoped_disable; | |||
using Flags = Tensor::Flags; | |||
namespace { | |||
struct GradSlotWeakPtr { | |||
std::weak_ptr<GradFn> grad_fn; | |||
size_t idx; | |||
}; | |||
std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||
ApplyContext& ctx, const apply_result_t& outputs) { | |||
// hash | |||
using OptimizedBackwardGraphCache = OpMethResultCache< | |||
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>; | |||
thread_local OptimizedBackwardGraphCache cache; | |||
decltype(cache)::key_t cache_key{ctx.op}; | |||
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs; | |||
SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras); | |||
input_descs.resize(ctx.nargs); | |||
input_requires_grad.resize(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
input_descs[i].layout.dtype = ctx.args[i]->dtype(); | |||
input_descs[i].comp_node = ctx.args[i]->comp_node(); | |||
input_requires_grad[i] = python::input_requires_grad(ctx, i); | |||
} | |||
auto iter = cache.find(cache_key); | |||
if (iter != cache.end()) { | |||
return iter->second; | |||
} | |||
// slow path | |||
SmallVector<bool> output_has_grad(outputs.size(), true); | |||
std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||
auto bg = OpDef::make_backward_graph( | |||
*ctx.op, input_descs, input_requires_grad, output_has_grad); | |||
if (!bg.graph.empty()) { | |||
ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||
} | |||
cache.emplace(cache_key, ret); | |||
return ret; | |||
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map; | |||
} | |||
struct BackwardGraphWithClosure { | |||
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph; | |||
SmallVector<std::shared_ptr<Tensor>> closure; | |||
size_t output_mask_offset; | |||
size_t grad_mask_offset; | |||
BackwardGraphWithClosure( | |||
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_, | |||
ApplyContext& ctx, const apply_result_t& outputs) | |||
: backward_graph(backward_graph_), | |||
output_mask_offset(ctx.nargs), | |||
grad_mask_offset(ctx.nargs + outputs.size()) { | |||
// save_for_backward[0:nargs]: | |||
// whether input is kept for backward | |||
// | |||
// save_for_backward[nargs:nargs+outputs.size()]: | |||
// whether output is kept for backward | |||
// | |||
// save_for_backward[-outputs.size():]: | |||
// whether gradient of output can propagate to any input | |||
// | |||
// Example: | |||
// perform c = a * b, with a.requires_grad == True and | |||
// b.requires_grad == False, save_for_backward = [0, 1, 0, 1] | |||
auto& save_for_backward = backward_graph->save_for_backward; | |||
mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); | |||
size_t count = std::count_if( | |||
save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); | |||
if (!backward_graph->precomp.empty()) { | |||
auto&& irng = ranges::span(ctx.args, ctx.nargs); | |||
auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); }); | |||
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | |||
closure.reserve(precomp.size() + count); | |||
std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure)); | |||
} else { | |||
closure.reserve(count); | |||
} | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
if (save_for_backward[i]) { | |||
closure.push_back(ctx.args[i]->shared_from_this()); | |||
} | |||
} | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
if (save_for_backward[ctx.nargs + i]) { | |||
closure.push_back(outputs[i]); | |||
} | |||
} | |||
} | |||
template <typename T, typename R> | |||
void operator()(BackwardContext&, T&& grads, R&& receiver) { | |||
Tensor* args[closure.size() + grads.size()]; | |||
size_t nargs = 0; | |||
for (auto&& t : closure) { | |||
args[nargs++] = t.get(); | |||
} | |||
bool null_grad = false; | |||
for (size_t i = 0; i < grads.size(); ++i) { | |||
if (backward_graph->save_for_backward[grad_mask_offset + i]) { | |||
if (grads[i]) { | |||
if (null_grad) { | |||
PyErr_SetString(PyExc_NotImplementedError, "report to devs"); | |||
throw py::error_already_set(); | |||
} | |||
args[nargs++] = grads[i]; | |||
} else { | |||
null_grad = true; | |||
} | |||
} | |||
} | |||
if (null_grad) | |||
return; | |||
auto igrads = apply(backward_graph->backward, args, nargs); | |||
auto&& it = igrads.begin(); | |||
for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) { | |||
if (p) { | |||
receiver(i, std::move(*it)); | |||
++it; | |||
} | |||
} | |||
} | |||
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } | |||
bool output_requires_grad(size_t i) { | |||
return backward_graph->save_for_backward[grad_mask_offset + i]; | |||
} | |||
bool output_captured(size_t i) { | |||
return backward_graph->save_for_backward[output_mask_offset + i]; | |||
} | |||
}; | |||
struct PythonBackward { | |||
py::object pyfunc; | |||
size_t input_size; | |||
PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {} | |||
template <typename T, typename R> | |||
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||
auto args = py::tuple(grads.size()); | |||
for (size_t i = 0; i < grads.size(); ++i) { | |||
auto&& g = grads[i]; | |||
args[i] = g ? ctx.wrap_tensor(g) : py::none(); | |||
} | |||
auto input_grads = py::reinterpret_steal<py::object>( | |||
PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); | |||
if (!input_grads) | |||
throw py::error_already_set(); | |||
if (input_grads.is_none()) | |||
return; | |||
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { | |||
if (input_size != 1) { | |||
throw py::value_error( | |||
"custom grad rule returned wrong number of grads"); | |||
} | |||
if (!ctx.pytype) { | |||
ctx.pytype = Py_TYPE(input_grads.ptr()); | |||
} | |||
receiver(0, tw->m_tensor); | |||
return; | |||
} | |||
if (py::len(input_grads) != input_size) { | |||
throw py::value_error("custom grad rule returned wrong number of grads"); | |||
} | |||
for (auto [i, g] : views::enumerate(input_grads)) { | |||
if (g.is_none()) | |||
continue; | |||
auto* tw = TensorWrapper::try_cast(g.ptr()); | |||
if (!tw) { | |||
throw py::type_error("custom grad rule returned non-tensor"); | |||
} | |||
if (!ctx.pytype) { | |||
ctx.pytype = Py_TYPE(g.ptr()); | |||
} | |||
receiver(i, tw->m_tensor); | |||
} | |||
} | |||
static constexpr bool input_has_grad(size_t) { return true; } | |||
static constexpr bool output_requires_grad(size_t) { return true; } | |||
static constexpr bool output_captured(size_t) { return true; } | |||
}; | |||
} // namespace | |||
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> { | |||
using Base = intrusive_list::Node<GradProducerRecord>; | |||
GradProducerRecord() = default; | |||
GradProducerRecord(GradProducerRecord::head_t& head) | |||
: Base(intrusive_list::after_t{}, head) {} | |||
// GradProducerRecord(GradProducerRecord&&) = default; | |||
// GradProducerRecord& operator=(GradProducerRecord&) = default; | |||
// GradProducerRecord& operator=(GradProducerRecord&&) = default; | |||
}; | |||
struct GradSlot { | |||
std::shared_ptr<Tensor> grad; | |||
py::object callback; | |||
GradProducerRecord::head_t producer_head; | |||
}; | |||
struct GradSlotProducerPtr : GradSlotPtr { | |||
GradProducerRecord producer_record; | |||
GradSlotProducerPtr() = default; | |||
GradSlotProducerPtr(GradInfo& info) | |||
: GradSlotPtr(info), producer_record(info->producer_head) {} | |||
}; | |||
struct GradFn : std::enable_shared_from_this<GradFn> { | |||
static MemPool<GradFn> pool; | |||
std::weak_ptr<GradKey> key; | |||
// slots for receiving and accumulating grads | |||
// same length as outputs (of forward op) | |||
SmallVector<GradSlot> slots; | |||
// where to send and accumulate grads | |||
// same length as inputs (of forward op) | |||
SmallVector<GradSlotProducerPtr> dsts; | |||
// encapsulates the actual function that computes the gradient | |||
std::variant< | |||
std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> | |||
backward; | |||
// a flag used during backward | |||
bool in_ref_keeper = false; | |||
static void deleter(GradFn* ptr) { pool.free(ptr); } | |||
static std::shared_ptr<GradFn> make() { | |||
return std::shared_ptr<GradFn>(pool.alloc(), &deleter); | |||
} | |||
void clear() { | |||
key.reset(); | |||
slots.clear(); | |||
dsts.clear(); | |||
backward.emplace<std::monostate>(); | |||
} | |||
}; | |||
GradSlotPtr::operator bool() const { | |||
return bool(grad_fn); | |||
} | |||
GradSlot* GradSlotPtr::operator->() { | |||
return &grad_fn->slots[idx]; | |||
} | |||
namespace { | |||
class GradFnHelper { | |||
std::shared_ptr<GradFn> grad_fn; | |||
GradFn* get() { | |||
if (!grad_fn) { | |||
grad_fn = std::make_shared<GradFn>(); | |||
} | |||
return grad_fn.get(); | |||
} | |||
friend apply_result_t imperative::python::apply_grad(ApplyContext&); | |||
public: | |||
template <typename T, typename... Args> | |||
auto& emplace(Args&&... args) { | |||
return get()->backward.emplace<T>(std::forward<Args>(args)...); | |||
} | |||
void reset() { grad_fn = nullptr; } | |||
}; | |||
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||
// copy inputs first, or trace will make InputNodes for each usage | |||
ApplyContext ctx_dup = ctx; | |||
SmallVector<std::shared_ptr<Tensor>> inputs_copy; | |||
SmallVector<Tensor*> inputs_copy_weak; | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
Tensor* input = ctx.args[i]; | |||
inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]); | |||
inputs_copy_weak.push_back(inputs_copy.back().get()); | |||
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; | |||
if (input->m_flags & Flags::GRAD) { | |||
inputs_copy.back()->m_flags |= Flags::GRAD; | |||
} | |||
} | |||
ctx_dup.args = inputs_copy_weak.data(); | |||
auto outputs = apply(ctx_dup); | |||
auto backward_graph = make_backward_graph(ctx_dup, outputs); | |||
if (!backward_graph) { | |||
return outputs; | |||
} | |||
ret_grad_fn.emplace<BackwardGraphWithClosure>( | |||
std::move(backward_graph), ctx_dup, outputs); | |||
return outputs; | |||
} | |||
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||
auto* op = ctx.op->try_cast_final<GenericPyOp>(); | |||
py::tuple pyin(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); | |||
} | |||
auto grad_rule = py::getattr(op->obj, "_grad_rule"); | |||
auto pyret = py::reinterpret_steal<py::object>( | |||
PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); | |||
if (!pyret) | |||
throw py::error_already_set(); | |||
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret); | |||
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs); | |||
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { | |||
return {tw->m_tensor}; | |||
} | |||
apply_result_t ret; | |||
ret.reserve(py::len(outputs)); | |||
for (auto&& i : outputs) { | |||
auto* tw = TensorWrapper::try_cast(i.ptr()); | |||
mgb_assert(tw); | |||
ret.push_back(tw->m_tensor); | |||
} | |||
return ret; | |||
} | |||
} // namespace | |||
apply_result_t apply_grad(ApplyContext& ctx) { | |||
std::unordered_set<std::shared_ptr<GradKey>> grad_keys; | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
auto* tensor = ctx.args[i]; | |||
if (!tensor->m_grad_info_dict.empty()) { | |||
size_t grad_cnt = 0; | |||
for (auto&& grad_info : tensor->m_grad_info_dict) { | |||
auto input_grad_key = grad_info.grad_fn->key.lock(); | |||
if (input_grad_key && input_grad_key->active && | |||
!input_grad_key->is_blocked()) { | |||
grad_keys.insert(input_grad_key); | |||
grad_cnt++; | |||
} | |||
} | |||
if (!grad_cnt) { | |||
tensor->m_flags &= ~Flags::GRAD; | |||
} | |||
} else { | |||
tensor->m_flags &= ~Flags::GRAD; | |||
} | |||
} | |||
ctx.flags &= ~Flags::GRAD; | |||
if (grad_keys.empty()) { | |||
return apply(ctx); | |||
} else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) { | |||
PyErr_SetString( | |||
PyExc_NotImplementedError, | |||
"second order directive not enabled, please call " | |||
"'megengine.experimental.enable_higher_order_directive'"); | |||
throw pyext17::py_err_set(); | |||
} | |||
GradFnHelper grad_fn_holder; | |||
auto outputs = [&]() { | |||
auto _ = scoped_disable(Flags::GRAD); | |||
if (ctx.op->same_type<GenericPyOp>()) { | |||
return python_grad_rule(ctx, grad_fn_holder); | |||
} | |||
auto&& registry = grad_rule_registry(); | |||
auto&& it = registry.find(ctx.op->dyn_typeinfo()); | |||
if (it != registry.end()) { | |||
auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx); | |||
if (auto ret = it->second(ctx, maker)) { | |||
maker.finalize(); | |||
return *ret; | |||
} | |||
grad_fn_holder.reset(); | |||
} | |||
return backward_graph_grad_rule(ctx, grad_fn_holder); | |||
}(); | |||
if (!grad_fn_holder.grad_fn) { | |||
return outputs; | |||
} | |||
for (auto&& grad_key : grad_keys) { | |||
auto grad_fn = std::make_shared<GradFn>(); | |||
grad_fn->backward = grad_fn_holder.grad_fn->backward; | |||
grad_fn->key = grad_key; | |||
grad_fn->slots.resize(outputs.size()); | |||
grad_fn->dsts.reserve(ctx.nargs); | |||
std::visit( | |||
[&](auto& backward) { | |||
using T = std::decay_t<decltype(backward)>; | |||
if constexpr (std::is_same_v<T, std::monostate>) { | |||
mgb_assert(0); | |||
} else { | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
if (backward.input_has_grad(i) && | |||
input_requires_grad(ctx, i) && | |||
ctx.args[i]->m_grad_info_dict.count(grad_key.get())) { | |||
auto& input_grad_info = | |||
ctx.args[i]->m_grad_info_dict.at( | |||
grad_key.get()); | |||
grad_fn->dsts.emplace_back(input_grad_info); | |||
// register as grad producer | |||
grad_fn->dsts.back().producer_record.insert_after( | |||
input_grad_info->producer_head); | |||
} else { | |||
grad_fn->dsts.emplace_back(); | |||
} | |||
} | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
if (backward.output_requires_grad(i)) { | |||
if (backward.output_captured(i)) { | |||
// avoid reference cycle [Tensor <-> GradFn] | |||
static std::shared_ptr<OpDef> op = | |||
std::make_shared<FastpathCopy>(); | |||
outputs[i] = python::apply(op, outputs[i])[0]; | |||
} | |||
// populate grad info of output tensor | |||
auto& grad_info = | |||
outputs[i]->m_grad_info_dict[grad_key.get()]; | |||
grad_info.grad_fn = grad_fn; | |||
grad_info.idx = i; | |||
grad_info.insert_after(grad_key->free_vars_head); | |||
outputs[i]->m_flags |= Flags::GRAD; | |||
} | |||
} | |||
} | |||
}, | |||
grad_fn->backward); | |||
// record forward history | |||
grad_key->tape.emplace_back(grad_fn); | |||
} | |||
return outputs; | |||
} | |||
PyObject* GradKeyWrapper::get_priority() { | |||
return py::cast(m_key->priority).release().ptr(); | |||
} | |||
void GradKeyWrapper::set_priority(pybind11::handle priority) { | |||
m_key->priority = py::cast<int>(priority); | |||
GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) { | |||
grad_key_map[m_key] = this; | |||
} | |||
void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
@@ -488,157 +41,59 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
if (!tw) { | |||
throw py::type_error("argument 1 must be Tensor"); | |||
} | |||
auto* tensor = tw->m_tensor.get(); | |||
py::object callback; | |||
if (args[1] != Py_None) { | |||
callback = py::reinterpret_borrow<py::object>(args[1]); | |||
} | |||
m_key->attach(tensor, std::move(callback)); | |||
} | |||
//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach | |||
void GradKey::attach(Tensor* tensor, pybind11::object callback) { | |||
if (!active) { | |||
throw py::value_error("grad key finalized"); | |||
} | |||
if (tensor->m_grad_info_dict.count(this)) { | |||
if (tensor->m_grad_info_dict.at(this)->callback) { | |||
throw py::value_error("callback already set on this tensor"); | |||
GenericFunction generic_callback = | |||
[=](Span<ValueRef> inputs) -> std::vector<ValueRef> { | |||
mgb_assert(inputs.size() == 1); | |||
if (callback) { | |||
callback(TensorWrapper::make(py_tensor_type, inputs[0])); | |||
} | |||
} else { | |||
auto& grad_info = tensor->m_grad_info_dict[this]; | |||
grad_info.idx = 0; | |||
auto& grad_fn = grad_info.grad_fn; | |||
grad_fn = std::make_shared<GradFn>(); | |||
grad_fn->key = shared_from_this(); | |||
grad_fn->slots.resize(1); | |||
grad_info.insert_after(free_vars_head); | |||
tensor->m_flags |= Flags::GRAD; | |||
} | |||
tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback); | |||
} | |||
template <typename T> | |||
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) { | |||
if (!grad) { | |||
grad = std::forward<T>(delta); | |||
return; | |||
} | |||
static std::shared_ptr<OpDef> op = | |||
std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD)); | |||
grad = apply(op, grad, std::forward<T>(delta))[0]; | |||
return {}; | |||
}; | |||
tw->m_tensor->reset(imperative::apply( | |||
AttachGrad(m_key), tw->m_tensor->data(), | |||
FunctionValue::make(generic_callback))[0]); | |||
} | |||
void GradKey::backward( | |||
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||
if (!active) { | |||
throw py::value_error("finalized"); | |||
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { | |||
std::vector<ValueRef> args; | |||
mgb_assert(tensors.size() == grads.size()); | |||
for (auto&& tensor : tensors) { | |||
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||
} | |||
if (tensors.size() != grads.size()) { | |||
throw py::value_error("tensor and grad size mismatch"); | |||
for (auto&& grad : grads) { | |||
args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); | |||
} | |||
// this GradKey is marked inactive here | |||
active = false; | |||
struct CleanupGuard { | |||
GradKey* owner; | |||
size_t priority_backup; | |||
CleanupGuard(GradKey* this_) : owner(this_) { | |||
priority_backup = sm_min_priority; | |||
sm_min_priority = owner->priority + 1; | |||
} | |||
~CleanupGuard() { | |||
owner->cleanup(); | |||
sm_min_priority = priority_backup; | |||
} | |||
} _cleanup_guard(this); | |||
if (tape.empty()) | |||
return; | |||
BackwardContext bctx; | |||
if (!grads.empty()) { | |||
bctx.pytype = Py_TYPE(grads[0]->self().ptr()); | |||
} | |||
for (size_t i = 0; i < tensors.size(); ++i) { | |||
if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) { | |||
continue; | |||
} | |||
auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this); | |||
grad_info->grad = grads[i]->m_tensor; | |||
} | |||
std::vector<std::shared_ptr<GradFn>> ref_keeper; | |||
ref_keeper.reserve(tape.size()); | |||
// back-propagation in reverse order | |||
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { | |||
auto&& grad_fn = tape[k].lock(); | |||
if (!grad_fn) | |||
continue; | |||
auto grad_receiver = [&](size_t i, auto&& g) { | |||
auto& dst = grad_fn->dsts[i]; | |||
if (dst) { | |||
accum_grad(dst->grad, std::forward<decltype(g)>(g)); | |||
} | |||
}; | |||
std::visit( | |||
[&](auto&& backward) { | |||
using T = std::decay_t<decltype(backward)>; | |||
if constexpr (std::is_same_v<T, std::monostate>) { | |||
mgb_assert(0); | |||
} else { | |||
auto&& grads = views::transform( | |||
grad_fn->slots, | |||
[](auto&& slot) { return slot.grad.get(); }); | |||
backward( | |||
bctx, std::forward<decltype(grads)>(grads), | |||
grad_receiver); | |||
} | |||
}, | |||
grad_fn->backward); | |||
for (auto&& dst : grad_fn->dsts) { | |||
if (!dst.grad_fn) | |||
continue; | |||
if (!dst.grad_fn->in_ref_keeper) { | |||
// after grad_fn is cleared, refcnt of subsequent grad_fn | |||
// could drop to 0 | |||
dst.grad_fn->in_ref_keeper = true; | |||
ref_keeper.push_back(dst.grad_fn); | |||
} | |||
if (!dst.producer_record.next && dst->callback && dst->grad) { | |||
// I'm the last grad producer, invoke callback | |||
dst->callback(bctx.wrap_tensor(dst->grad)); | |||
} | |||
} | |||
grad_fn->clear(); | |||
} // finish tape loop | |||
imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); | |||
} | |||
void GradKey::cleanup() { | |||
active = false; | |||
tape.clear(); | |||
for (intrusive_list::Iterator it(free_vars_head); it;) { | |||
it->grad_fn.reset(); | |||
(it++)->unlink(); | |||
pybind11::function GradKeyWrapper::get_backward_closure( | |||
GradKeyWrapper* self, py::list tensors) { | |||
std::vector<ValueRef> args; | |||
for (auto&& tensor : tensors) { | |||
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||
} | |||
} | |||
void GradKeyWrapper::backward( | |||
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||
m_key->backward(std::move(tensors), std::move(grads)); | |||
auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0] | |||
.as<FunctionValue>(); | |||
auto py_function = [closure](std::vector<TensorWrapper*> tensors) { | |||
std::vector<ValueRef> args; | |||
for (auto* tw : tensors) { | |||
args.push_back(tw->m_tensor->data()); | |||
} | |||
(*closure)(args); | |||
}; | |||
return pybind11::cpp_function(py_function); | |||
} | |||
PyObject* GradKeyWrapper::get_name() { | |||
return py::cast(m_key->name).release().ptr(); | |||
return py::cast(m_key->name()).release().ptr(); | |||
} | |||
void GradKeyWrapper::set_name(py::handle name) { | |||
m_key->name = py::cast<std::string>(name); | |||
m_key->name(py::cast<std::string>(name)); | |||
} | |||
PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||
@@ -651,60 +106,39 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||
PyErr_SetString(PyExc_TypeError, "expect Tensor"); | |||
return nullptr; | |||
} | |||
if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) { | |||
if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0] | |||
.cast<BoolValue>()) { | |||
Py_RETURN_TRUE; | |||
} | |||
Py_RETURN_FALSE; | |||
} | |||
int GradKey::sm_min_priority = std::numeric_limits<int>::min(); | |||
GradKey::~GradKey() { | |||
cleanup(); | |||
void GradKeyWrapper::enter() { | |||
m_transformation = std::make_shared<GradTransformation>(m_key); | |||
TransformationManager::get_instance().register_at<TransformationManager::Grad>( | |||
m_transformation); | |||
} | |||
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() { | |||
static std::unordered_map<Typeinfo*, GradRuleFn> registry; | |||
return registry; | |||
void GradKeyWrapper::exit() { | |||
TransformationManager::get_instance().unregister<TransformationManager::Grad>( | |||
m_transformation); | |||
m_transformation.reset(); | |||
} | |||
void GradInfoCollection::_shrink() { | |||
auto pred = [](GradInfo& info) { | |||
return !(info.grad_fn) || info.grad_fn->key.expired(); | |||
}; | |||
auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred); | |||
m_storage.erase(iter, m_storage.end()); | |||
void GradKeyWrapper::suppress() { | |||
m_transformation->suppress(); | |||
} | |||
bool GradInfoCollection::contains(GradKey* key) { | |||
_shrink(); | |||
for (auto&& grad_info : m_storage) { | |||
if (grad_info.grad_fn->key.lock().get() == key) { | |||
return true; | |||
} | |||
} | |||
return false; | |||
void GradKeyWrapper::resume() { | |||
m_transformation->resume(); | |||
} | |||
GradInfo& GradInfoCollection::operator[](GradKey* key) { | |||
_shrink(); | |||
for (auto&& grad_info : m_storage) { | |||
if (grad_info.grad_fn->key.lock().get() == key) { | |||
return grad_info; | |||
} | |||
} | |||
m_storage.emplace_back(); | |||
return m_storage.back(); | |||
GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) { | |||
return grad_key_map.at(key); | |||
} | |||
GradInfo& GradInfoCollection::at(GradKey* key) { | |||
_shrink(); | |||
for (auto&& grad_info : m_storage) { | |||
if (grad_info.grad_fn->key.lock().get() == key) { | |||
return grad_info; | |||
} | |||
} | |||
mgb_assert(false); | |||
GradKeyWrapper::~GradKeyWrapper() { | |||
grad_key_map.erase(m_key); | |||
} | |||
} // namespace mgb::imperative::python |
@@ -12,166 +12,40 @@ | |||
#pragma once | |||
#include "./tensor.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include "megbrain/imperative/transformations/grad.h" | |||
#include "megbrain/utils/small_vector.h" | |||
#include <megbrain/utils/small_vector.h> | |||
#include <memory> | |||
#include <optional> | |||
namespace mgb::imperative::python { | |||
apply_result_t apply_grad(ApplyContext& ctx); | |||
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj { | |||
std::string name; | |||
bool active = true; | |||
GradInfo::head_t free_vars_head; | |||
std::vector<std::weak_ptr<GradFn>> tape; | |||
int priority = 0; | |||
~GradKey(); | |||
void attach(Tensor* tensor, pybind11::object callback); | |||
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||
void cleanup(); | |||
bool is_blocked() const { return priority < sm_min_priority; } | |||
inline static bool allow_higher_order_directive = false; | |||
private: | |||
static int sm_min_priority; | |||
}; | |||
struct GradKeyWrapper { | |||
struct GradKeyWrapper : NonCopyableObj { | |||
using wrap_t = pyext17::wrap<GradKeyWrapper>; | |||
static constexpr auto tp_name = pybind11::detail::_("GradKey"); | |||
std::shared_ptr<GradKey> m_key; | |||
std::shared_ptr<GradTransformation> m_transformation; | |||
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {} | |||
GradKeyWrapper(); | |||
PyObject* get_name(); | |||
void set_name(pybind11::handle name); | |||
PyObject* get_priority(); | |||
void set_priority(pybind11::handle priority); | |||
void attach(PyObject* const* args, size_t nargs); | |||
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||
static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list); | |||
static pybind11::function get_backward_closure( | |||
GradKeyWrapper* self, pybind11::list); | |||
PyObject* is_attached_to(PyObject* const* args, size_t nargs); | |||
void enter(); | |||
void exit(); | |||
void suppress(); | |||
void resume(); | |||
static GradKeyWrapper* get(std::shared_ptr<GradKey> key); | |||
~GradKeyWrapper(); | |||
}; | |||
struct BackwardContext { | |||
PyTypeObject* pytype = nullptr; | |||
auto wrap_tensor(std::shared_ptr<Tensor> t) { | |||
if (pytype) { | |||
return TensorWrapper::make(pytype, std::move(t)); | |||
} | |||
return TensorWrapper::make(std::move(t)); | |||
} | |||
auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); } | |||
}; | |||
struct CustomBackward { | |||
using BackwardFn = | |||
std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>; | |||
BackwardFn m_backward; | |||
SmallVector<bool, 8> m_input_has_grad; | |||
struct OutputAttr { | |||
bool requires_grad = true, captured = true; | |||
}; | |||
SmallVector<OutputAttr> m_output_attrs; | |||
public: | |||
template <typename T, typename R> | |||
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||
size_t nargs = grads.size(); | |||
Tensor* args[nargs]; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
args[i] = grads[i]; | |||
} | |||
auto ret = m_backward(ctx, args, nargs); | |||
for (size_t i = 0; i < ret.size(); ++i) { | |||
if (auto&& t = ret[i]) { | |||
receiver(i, std::move(t)); | |||
} | |||
} | |||
} | |||
bool input_has_grad(size_t i) { return m_input_has_grad[i]; } | |||
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } | |||
bool output_captured(size_t i) { return m_output_attrs[i].captured; } | |||
class Maker { | |||
bool output_size_set = false, input_has_grad_initialized = false; | |||
CustomBackward& target; | |||
ApplyContext& ctx; | |||
void init_input_has_grad() { | |||
if (!input_has_grad_initialized) { | |||
input_has_grad_initialized = true; | |||
target.m_input_has_grad.resize(ctx.nargs, true); | |||
} | |||
} | |||
public: | |||
Maker(CustomBackward& target_, ApplyContext& ctx_) | |||
: target(target_), ctx(ctx_) {} | |||
template <typename F> | |||
Maker& backward(F&& f) { | |||
mgb_assert(!target.m_backward); | |||
target.m_backward = std::forward<F>(f); | |||
return *this; | |||
} | |||
// mandatory | |||
Maker& output_size(size_t sz) { | |||
mgb_assert(!output_size_set); | |||
output_size_set = true; | |||
target.m_output_attrs.resize(sz); | |||
return *this; | |||
} | |||
// optional, defaults to all true | |||
Maker& input_has_grad(size_t i, bool v) { | |||
init_input_has_grad(); | |||
target.m_input_has_grad.at(i) = v; | |||
return *this; | |||
} | |||
// optional, defaults to all true | |||
Maker& output_requires_grad(size_t i, bool v) { | |||
target.m_output_attrs.at(i).requires_grad = v; | |||
return *this; | |||
} | |||
// optional, defaults to all true | |||
Maker& output_captured(size_t i, bool v) { | |||
target.m_output_attrs.at(i).captured = v; | |||
return *this; | |||
} | |||
void finalize() { | |||
mgb_assert(output_size_set); | |||
init_input_has_grad(); | |||
} | |||
}; | |||
Maker maker(ApplyContext& ctx) { return {*this, ctx}; } | |||
}; | |||
using GradRuleFn = std::function<std::optional<apply_result_t>( | |||
ApplyContext&, CustomBackward::Maker&)>; | |||
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry(); | |||
inline bool input_requires_grad(const ApplyContext& ctx, size_t i) { | |||
return !ctx.args[i]->m_grad_info_dict.empty(); | |||
} | |||
struct GradRuleFallback : std::exception {}; | |||
template <typename T> | |||
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) { | |||
return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second; | |||
} | |||
} // namespace mgb::imperative::python | |||
namespace pybind11::detail { | |||
@@ -1,43 +0,0 @@ | |||
/** | |||
* \file imperative/python/src/grad_info.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <memory> | |||
#include "./intrusive_list.h" | |||
namespace mgb::imperative::python { | |||
struct GradKey; | |||
struct GradFn; | |||
struct GradSlot; | |||
struct GradSlotPtr { | |||
std::shared_ptr<GradFn> grad_fn; | |||
size_t idx; | |||
operator bool() const; | |||
GradSlot* operator->(); | |||
}; | |||
struct GradInfo : GradSlotPtr, | |||
intrusive_list::Node<GradInfo, intrusive_list::before_t> { | |||
GradInfo() = default; | |||
GradInfo(GradInfo&) = default; | |||
GradInfo(GradInfo&&) = default; | |||
GradInfo& operator=(GradInfo&) = default; | |||
GradInfo& operator=(GradInfo&&) = default; | |||
GradInfo(const GradInfo& rhs) : GradInfo(const_cast<GradInfo&>(rhs)) {} | |||
GradInfo& operator=(const GradInfo& rhs) { | |||
return *this = const_cast<GradInfo&>(rhs); | |||
} | |||
}; | |||
} // namespace mgb::imperative::python |
@@ -11,261 +11,334 @@ | |||
#include "./grad.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/transformations/grad.h" | |||
namespace mgb::imperative::python { | |||
class CustomGradMaker { | |||
bool output_size_set = false, input_has_grad_initialized = false; | |||
CustomBackward& target; | |||
size_t nr_inputs; | |||
void init_input_has_grad() { | |||
if (!input_has_grad_initialized) { | |||
input_has_grad_initialized = true; | |||
target.m_input_has_grad.resize(nr_inputs, true); | |||
} | |||
} | |||
public: | |||
CustomGradMaker(CustomBackward& target, size_t nr_inputs) | |||
: target(target), nr_inputs(nr_inputs) {} | |||
CustomGradMaker& backward(CustomBackward::BackwardFn f) { | |||
mgb_assert(!target.m_backward); | |||
target.m_backward = f; | |||
return *this; | |||
} | |||
// mandatory | |||
CustomGradMaker& output_size(size_t sz) { | |||
mgb_assert(!output_size_set); | |||
output_size_set = true; | |||
target.m_output_attrs.resize(sz); | |||
return *this; | |||
} | |||
// optional, defaults to all true | |||
CustomGradMaker& input_has_grad(size_t i, bool v) { | |||
init_input_has_grad(); | |||
target.m_input_has_grad.at(i) = v; | |||
return *this; | |||
} | |||
// optional, defaults to all true | |||
CustomGradMaker& output_requires_grad(size_t i, bool v) { | |||
target.m_output_attrs.at(i).requires_grad = v; | |||
return *this; | |||
} | |||
// optional, defaults to all true | |||
CustomGradMaker& output_captured(size_t i, bool v) { | |||
target.m_output_attrs.at(i).captured = v; | |||
return *this; | |||
} | |||
void finalize() { | |||
mgb_assert(output_size_set); | |||
init_input_has_grad(); | |||
} | |||
}; | |||
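// For orientation, a minimal sketch of how the rules below drive this
// builder. The identity backward mirrors fastpathcopy_grad_rule further
// down; configure_identity_backward is an illustrative name, not part of
// this diff:
void configure_identity_backward(CustomBackward& backward) {
    auto maker = CustomGradMaker(backward, /*nr_inputs=*/1);
    maker.output_size(1).output_captured(0, false);  // output_size is mandatory
    maker.backward([](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        std::vector<ValueRef> ret(1);
        if (grads[0]) {
            ret[0] = grads[0];  // gradient passes through unchanged
        }
        return ret;
    });
    maker.finalize();  // asserts that output_size() was called
}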
namespace { | |||
std::shared_ptr<Tensor> get_shape(Tensor* x) { | |||
ValueRef get_shape(ValueRef x) { | |||
static auto op = GetVarShape::make(); | |||
return python::apply(op, x)[0]; | |||
return imperative::apply(*op, x)[0]; | |||
} | |||
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) { | |||
ValueRef reduce_to(ValueRef x, ValueRef s) { | |||
static auto op = Reduce::make(); | |||
return python::apply(op, x, s)[0]; | |||
return imperative::apply(*op, x, s)[0]; | |||
} | |||
std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) { | |||
ValueRef reshape_to(ValueRef x, ValueRef s) { | |||
static auto op = Reshape::make(); | |||
return python::apply(op, x, s)[0]; | |||
return imperative::apply(*op, x, s)[0]; | |||
} | |||
std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) { | |||
ValueRef broadcast_to(ValueRef x, ValueRef s) { | |||
static auto op = Broadcast::make(); | |||
return python::apply(op, x, s)[0]; | |||
return imperative::apply(*op, x, s)[0]; | |||
} | |||
std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) { | |||
HostTensorND scalar{cn, {{1}, dtype}}; | |||
std::memset(scalar.raw_ptr(), 0, dtype.size()); | |||
interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false); | |||
auto&& t = std::make_shared<Tensor>(handle); | |||
auto res = broadcast_to(t.get(), shape); | |||
ValueRef make_empty_tensor( | |||
CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) { | |||
HostTensorStorage storage(*device); | |||
storage.ensure_size(dtype->size()); | |||
std::memset(storage.ptr(), 0, dtype->size()); | |||
auto t = imperative::apply( | |||
CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()), | |||
HostStorage::make(storage))[0]; | |||
auto res = broadcast_to(t, shape); | |||
return res; | |||
} | |||
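// Note that make_empty_tensor produces a zero-filled tensor without ever
// allocating a dense host buffer: it materializes a single zero element via
// CreateTensor, then broadcasts it to the target shape. Usage follows the
// pattern in subtensor_grad_rule below; grad and saved_input_shape are
// illustrative names:
//
//   // build a zero tensor shaped like the forward input, then scatter the
//   // incoming gradient into it with SetSubtensor
//   auto zeros = make_empty_tensor(grad.device(), saved_input_shape, grad.dtype());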
std::optional<apply_result_t> elemwise_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
auto& op = ctx.op->cast_final_safe<Elemwise>(); | |||
if (op.mode == Elemwise::Mode::ADD) { | |||
mgb_assert(ctx.nargs == 2); | |||
std::array<std::shared_ptr<Tensor>, 2> input_shapes; | |||
std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto& elemwise = op.cast_final_safe<Elemwise>(); | |||
if (elemwise.mode != Elemwise::Mode::ADD) { | |||
return {}; | |||
} | |||
mgb_assert(inputs.size() == 2); | |||
std::array<ValueRef, 2> input_shapes; | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (inputs_require_grad[i]) { | |||
input_shapes[i] = get_shape(inputs[i]); | |||
} | |||
} | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (input_requires_grad(ctx, i)) { | |||
input_shapes[i] = get_shape(ctx.args[i]); | |||
if (shapes[i]) { | |||
ret[i] = reduce_to(grad, shapes[i]); | |||
} | |||
} | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([shapes = std::move(input_shapes)]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (shapes[i]) { | |||
ret[i] = reduce_to(grad, shapes[i].get()); | |||
} | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
} | |||
return {}; | |||
return ret; | |||
}); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<apply_result_t> reshape_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
mgb_assert(ctx.nargs == 2); | |||
std::array<std::shared_ptr<Tensor>, 2> input_shapes; | |||
std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
mgb_assert(inputs.size() == 2); | |||
std::array<ValueRef, 2> input_shapes; | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (input_requires_grad(ctx, i)) { | |||
input_shapes[i] = get_shape(ctx.args[i]); | |||
if (inputs_require_grad[i]) { | |||
input_shapes[i] = get_shape(inputs[i]); | |||
} | |||
} | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([shapes = std::move(input_shapes)]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(2); | |||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (shapes[i]) { | |||
ret[i] = reshape_to(grad, shapes[i].get()); | |||
ret[i] = reshape_to(grad, shapes[i]); | |||
} | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<apply_result_t> subtensor_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
auto&& op = ctx.op->cast_final_safe<Subtensor>(); | |||
auto&& grad_op = SetSubtensor::make(op.items); | |||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||
if (input_requires_grad(ctx, 0)) { | |||
inputs.push_back(get_shape(ctx.args[0])); | |||
for (size_t i = 1; i < ctx.nargs; ++i) { | |||
inputs.push_back(ctx.args[i]->copy()); | |||
std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& subtensor = op.cast_final_safe<Subtensor>(); | |||
auto&& grad_op = SetSubtensor::make(subtensor.items); | |||
SmallVector<ValueRef> inputs2; | |||
if (inputs_require_grad[0]) { | |||
inputs2.push_back(get_shape(inputs[0])); | |||
for (size_t i = 1; i < inputs.size(); ++i) { | |||
inputs2.push_back(inputs[i]); | |||
} | |||
} | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
maker.backward([inputs = std::move(inputs2), | |||
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
if (grad && inputs[0]) { | |||
SmallVector<Tensor*> args_(inputs.size() + 1); | |||
auto&& zeros = make_empty_tensor( | |||
grad->comp_node(), inputs[0].get(), grad->dtype()); | |||
args_[0] = zeros.get(); | |||
SmallVector<ValueRef> args_(inputs.size() + 1); | |||
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
args_[0] = zeros; | |||
args_[1] = grad; | |||
for (size_t i = 1; i < inputs.size(); ++i) { | |||
args_[i + 1] = inputs[i].get(); | |||
args_[i + 1] = inputs[i]; | |||
} | |||
ret[0] = python::apply(grad_op_, args_)[0]; | |||
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<apply_result_t> indexingMultiAxisVec_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>(); | |||
auto&& grad_op = IndexingSetMultiAxisVec::make(op.items); | |||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||
if (input_requires_grad(ctx, 0)) { | |||
inputs.push_back(get_shape(ctx.args[0])); | |||
for (size_t i = 1; i < ctx.nargs; ++i) { | |||
inputs.push_back(ctx.args[i]->copy()); | |||
std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | |||
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); | |||
SmallVector<ValueRef> inputs2; | |||
if (inputs_require_grad[0]) { | |||
inputs2.push_back(get_shape(inputs[0])); | |||
for (size_t i = 1; i < inputs.size(); ++i) { | |||
inputs2.push_back(inputs[i]); | |||
} | |||
} | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
maker.backward([inputs = std::move(inputs2), | |||
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
if (grad && inputs[0]) { | |||
SmallVector<Tensor*> args_(inputs.size() + 1); | |||
auto&& zeros = make_empty_tensor( | |||
grad->comp_node(), inputs[0].get(), grad->dtype()); | |||
args_[0] = zeros.get(); | |||
SmallVector<ValueRef> args_(inputs.size() + 1); | |||
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
args_[0] = zeros; | |||
args_[1] = grad; | |||
for (size_t i = 1; i < inputs.size(); ++i) { | |||
args_[i + 1] = inputs[i].get(); | |||
args_[i + 1] = inputs[i]; | |||
} | |||
ret[0] = python::apply(grad_op_, args_)[0]; | |||
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<apply_result_t> reduce_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
auto& op = ctx.op->cast_final_safe<Reduce>(); | |||
if (op.mode == Reduce::Mode::SUM) { | |||
if (ctx.nargs != 1) { | |||
return {}; | |||
} | |||
std::array<std::shared_ptr<Tensor>, 1> input_shapes; | |||
if (input_requires_grad(ctx, 0)) { | |||
input_shapes[0] = get_shape(ctx.args[0]); | |||
} | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([shapes = std::move(input_shapes)]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
if (grad && shapes[0]) { | |||
ret[0] = broadcast_to(grad, shapes[0].get()); | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto& reduce = op.cast_final_safe<Reduce>(); | |||
if (reduce.mode != Reduce::Mode::SUM) { | |||
return {}; | |||
} | |||
if (inputs.size() != 1) { | |||
return {}; | |||
} | |||
std::array<ValueRef, 1> input_shapes; | |||
if (inputs_require_grad[0]) { | |||
input_shapes[0] = get_shape(inputs[0]); | |||
} | |||
return {}; | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
if (grad && shapes[0]) { | |||
ret[0] = broadcast_to(grad, shapes[0]); | |||
} | |||
return ret; | |||
}); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<apply_result_t> addAxis_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
auto&& op = ctx.op->cast_final_safe<AddAxis>(); | |||
mgb_assert(ctx.nargs == 1); | |||
bool flag = input_requires_grad(ctx, 0); | |||
auto&& grad_op = RemoveAxis::make(op.axis); | |||
std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& addAxis = op.cast_final_safe<AddAxis>(); | |||
mgb_assert(inputs.size() == 1); | |||
bool flag = inputs_require_grad[0]; | |||
auto&& grad_op = RemoveAxis::make(addAxis.axis); | |||
std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>()); | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
if (grad && flag_) { | |||
ret[0] = python::apply(grad_op_, grad)[0]; | |||
ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
maker.finalize(); | |||
return imperative::apply(op, inputs); | |||
} | |||
std::optional<apply_result_t> removeAxis_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
auto&& op = ctx.op->cast_final_safe<RemoveAxis>(); | |||
mgb_assert(ctx.nargs == 1); | |||
bool flag = input_requires_grad(ctx, 0); | |||
auto&& grad_op = AddAxis::make(op.axis); | |||
std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& removeAxis = op.cast_final_safe<RemoveAxis>(); | |||
mgb_assert(inputs.size() == 1); | |||
bool flag = inputs_require_grad[0]; | |||
auto&& grad_op = AddAxis::make(removeAxis.axis); | |||
std::sort(grad_op->axis.begin(), grad_op->axis.end()); | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( | |||
BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
if (grad && flag_) { | |||
ret[0] = python::apply(grad_op_, grad)[0]; | |||
ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
maker.finalize(); | |||
return imperative::apply(op, inputs); | |||
} | |||
std::optional<apply_result_t> fastpathcopy_grad_rule( | |||
ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
mgb_assert(ctx.nargs == 1); | |||
std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
mgb_assert(inputs.size() == 1); | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
mgb_assert(ngrads == 1); | |||
Tensor* grad = grads[0]; | |||
apply_result_t ret(1); | |||
maker.backward([](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
if (grad) { | |||
ret[0] = grad->shared_from_this(); | |||
ret[0] = grad; | |||
} | |||
return ret; | |||
}); | |||
return apply(ctx); | |||
maker.finalize(); | |||
return imperative::apply(op, inputs); | |||
} | |||
struct Init { | |||
Init() { | |||
auto& reg = grad_rule_registry(); | |||
reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule); | |||
reg.emplace(Reshape::typeinfo(), reshape_grad_rule); | |||
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); | |||
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | |||
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | |||
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); | |||
CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule); | |||
CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule); | |||
CustomBackward::register_grad_rule( | |||
IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||
CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule); | |||
CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule); | |||
CustomBackward::register_grad_rule( | |||
RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||
CustomBackward::register_grad_rule( | |||
FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
} | |||
} _; | |||
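// All built-in rules are wired up once, at static-initialization time, by
// the anonymous Init instance above. A downstream op would hook into the
// same registry the same way; MyOp and my_grad_rule here are hypothetical
// stand-ins, not part of this diff:
namespace {
struct InitMyRule {
    InitMyRule() {
        // my_grad_rule must match the rule signature used above:
        // (const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)
        CustomBackward::register_grad_rule(MyOp::typeinfo(), my_grad_rule);
    }
} _my_rule;
}  // namespace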
@@ -1,245 +0,0 @@ | |||
/** | |||
* \file imperative/python/src/intrusive_list.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb::imperative::python::intrusive_list { | |||
// copy policy | |||
struct after_t {}; | |||
struct before_t {}; | |||
struct disable_t {}; | |||
template <typename T> | |||
struct Tail; | |||
// invariant: next->prev == this | |||
template <typename T> | |||
struct Head { | |||
Tail<T>* next; | |||
Head(Tail<T>* node = nullptr) : next(node) {} | |||
Head(const Head<T>&) = delete; | |||
Head<T>& operator=(const Head<T>&) = delete; | |||
Head(Head<T>&& rhs) : next(rhs.next) { | |||
rhs.next = nullptr; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
Head<T>& operator=(Head<T>&& rhs) { | |||
mgb_assert(!next); | |||
next = rhs.next; | |||
rhs.next = nullptr; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
return *this; | |||
} | |||
~Head() { | |||
if (next) { | |||
next->prev = nullptr; | |||
} | |||
} | |||
}; | |||
// invariant: prev->next == this | |||
template <typename T> | |||
struct Tail { | |||
Head<T>* prev; | |||
Tail(Head<T>* node = nullptr) : prev(node) {} | |||
Tail(const Tail<T>&) = delete; | |||
Tail<T>& operator=(const Tail<T>&) = delete; | |||
Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||
rhs.prev = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
} | |||
Tail<T>& operator=(Tail<T>&& rhs) { | |||
mgb_assert(!prev); | |||
prev = rhs.prev; | |||
rhs.prev = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
return *this; | |||
} | |||
~Tail() { | |||
if (prev) { | |||
prev->next = nullptr; | |||
} | |||
} | |||
}; | |||
template <typename T, typename policy> | |||
struct Node; | |||
template <typename T> | |||
class Iterator { | |||
T* ptr; | |||
void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); } | |||
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); } | |||
public: | |||
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||
template <typename policy> | |||
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||
T& operator*() { return *static_cast<T*>(ptr); } | |||
T* operator->() { return static_cast<T*>(ptr); } | |||
operator bool() { return ptr; } | |||
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; } | |||
Iterator& operator++() { | |||
inc(); | |||
return *this; | |||
} | |||
Iterator& operator--() { | |||
dec(); | |||
return *this; | |||
} | |||
Iterator operator++(int) { | |||
auto ret = *this; | |||
inc(); | |||
return ret; | |||
} | |||
Iterator operator--(int) { | |||
auto ret = *this; | |||
dec(); | |||
return ret; | |||
} | |||
}; | |||
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||
// Instead, nodes may join or leave a list freely. | |||
// NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||
// otherwise the compiler generated version would use the const T& signature, | |||
// which is deleted. | |||
template <typename T = void, typename policy = disable_t> | |||
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||
private: | |||
using this_t = Node<T, policy>; | |||
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||
public: | |||
using head_t = Head<U>; | |||
using tail_t = Tail<U>; | |||
using head_t::next; | |||
using tail_t::prev; | |||
Node() = default; | |||
Node(const this_t&) = delete; | |||
this_t& operator=(const this_t&) = delete; | |||
//! constructed node is inserted after the input node | |||
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||
node.next = this; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
//! constructed node is inserted before the input node | |||
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||
node.prev = this; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
} | |||
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||
rhs.prev = nullptr; | |||
rhs.next = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
Node& operator=(this_t&& rhs) { | |||
unlink(); | |||
prev = rhs.prev; | |||
next = rhs.next; | |||
rhs.prev = nullptr; | |||
rhs.next = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
if (next) { | |||
next->prev = this; | |||
} | |||
return *this; | |||
} | |||
template < | |||
typename p = policy, | |||
typename = std::enable_if_t< | |||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||
Node(this_t& rhs) : Node(policy{}, rhs) {} | |||
template < | |||
typename p = policy, | |||
typename = std::enable_if_t< | |||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||
this_t& operator=(this_t& rhs) { | |||
insert(policy{}, rhs); | |||
return *this; | |||
} | |||
void unlink() { | |||
if (prev) { | |||
prev->next = next; | |||
} | |||
if (next) { | |||
next->prev = prev; | |||
} | |||
prev = nullptr; | |||
next = nullptr; | |||
} | |||
//! this node is unlinked from its list and inserted after the input node | |||
void insert(after_t, head_t& node) { | |||
unlink(); | |||
prev = &node; | |||
next = node.next; | |||
node.next = this; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
//! this node is unlinked from its list and inserted before the input node | |||
void insert(before_t, tail_t& node) { | |||
unlink(); | |||
next = &node; | |||
prev = node.prev; | |||
node.prev = this; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
} | |||
void insert_before(tail_t& node) { insert(before_t{}, node); } | |||
void insert_after(head_t& node) { insert(after_t{}, node); } | |||
~Node() { unlink(); } | |||
}; | |||
} // namespace mgb::imperative::python::intrusive_list |
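// A minimal sketch of the before_t copy policy in action; Item and demo are
// illustrative only. This is the same mechanism the deleted GradInfo above
// relied on when grad infos were copied between tensors:
namespace il = mgb::imperative::python::intrusive_list;

struct Item : il::Node<Item, il::before_t> {
    int payload = 0;
    Item() = default;
    Item(Item&) = default;             // per the NOTE: re-declare as default
    Item& operator=(Item&) = default;  // so the non-const overloads are used
};

void demo(Item& existing) {
    Item copy = existing;  // splices `copy` into the list just before `existing`
    copy.unlink();         // nodes may leave the list freely at any time
}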
@@ -1,42 +0,0 @@ | |||
/** | |||
* \file imperative/python/src/module_trace.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./module_trace.h" | |||
#include "./helper.h" // include op pybind11 caster | |||
namespace py = pybind11; | |||
namespace mgb::imperative::python { | |||
apply_result_t apply_module_trace(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
auto args = py::tuple(ctx.nargs + 1); | |||
args[0] = py::cast(ctx.op); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||
} | |||
auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr); | |||
if (!pyout) | |||
throw py::error_already_set(); | |||
auto ret = py::reinterpret_steal<py::object>(pyout); | |||
// assumption: python function always returns PyList | |||
auto tup = py::reinterpret_borrow<py::list>(ret); | |||
for (auto i = 0; i < tup.size(); i++) { | |||
auto tw = TensorWrapper::try_cast(tup[i].ptr()); | |||
outputs.emplace_back(tw->m_tensor); | |||
} | |||
return outputs; | |||
} | |||
} // namespace mgb::imperative::python |
@@ -11,10 +11,50 @@ | |||
#pragma once | |||
#include "megbrain/imperative/transformations/trace.h" | |||
#include "megbrain/imperative/utils/map.h" | |||
#include "./tensor.h" | |||
namespace mgb::imperative::python { | |||
apply_result_t apply_module_trace(ApplyContext& ctx); | |||
namespace py = pybind11; | |||
class ModuleTraceTransformation final : public Transformation { | |||
private: | |||
py::function m_hook_fn; | |||
int m_enabled = 0; | |||
std::vector<ValueRef> apply_module_trace_hook( | |||
const OpDef& op, Span<ValueRef> input_values) { | |||
py::list input_tws; | |||
for (auto&& input_value : input_values) { | |||
input_tws.append(TensorWrapper::make(py_tensor_type, input_value)); | |||
} | |||
py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws); | |||
std::vector<ValueRef> outputs; | |||
for (auto&& output_tw : output_tws) { | |||
outputs.push_back( | |||
TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data()); | |||
} | |||
return outputs; | |||
} | |||
public: | |||
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | |||
std::vector<ValueRef> apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override { | |||
if (op.is<ApplyOp>() && m_enabled > 0) { | |||
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | |||
return outputs; | |||
} else { | |||
return imperative::apply(op, inputs); | |||
} | |||
} | |||
ValueRef unwrap(ValueRef value) override { return value; } | |||
std::string name() const override { return "ModuleTraceTransformation"; } | |||
}; | |||
} // namespace mgb::imperative::python |
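// ModuleTraceTransformation illustrates the general Transformation contract:
// inspect the operator in apply_transformation, optionally short-circuit it,
// and otherwise forward to imperative::apply. A pass-through skeleton for
// reference; NoopTransformation is illustrative and not part of this diff:
class NoopTransformation final : public Transformation {
public:
    std::vector<ValueRef> apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override {
        return imperative::apply(op, inputs);  // forward unchanged
    }
    ValueRef unwrap(ValueRef value) override { return value; }
    std::string name() const override { return "NoopTransformation"; }
};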
@@ -185,7 +185,8 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) { | |||
} | |||
PyGetSetDef PyOp(OpDef)::py_getsetters[] = { | |||
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL}, | |||
{const_cast<char*>("scope"), py_get_scope, py_set_scope, | |||
const_cast<char*>("scope"), NULL}, | |||
{NULL}}; | |||
Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) { | |||
@@ -556,12 +557,6 @@ void init_ops(py::module m) { | |||
m.def( | |||
"delete_rng_handle", | |||
[](size_t handle) { | |||
// RNG op might execute after handle released due to async dispatch, so | |||
// we need sync before delete a handle to avoid memory leak or | |||
// use-after-free | |||
if (python::interpreter_for_py->check_available()) { | |||
python::interpreter_for_py->sync(); | |||
} | |||
mgb::CompNode::sync_all(); | |||
py_task_q.wait_all_task_finish(); | |||
rng::delete_handle(handle); | |||
@@ -20,6 +20,8 @@ | |||
#include "pybind11/pybind11.h" | |||
#include "./pyext17.h" | |||
#include "megbrain/imperative/dispatch.h" | |||
#include "megbrain/imperative/utils/span.h" | |||
namespace mgb::imperative::python { | |||
@@ -32,126 +34,67 @@ struct ObjectPtr : B { | |||
} // namespace mgb::imperative::python | |||
#include "./grad_info.h" // for struct GradInfo | |||
#include "./trace_info.h" // for struct TraceInfo | |||
namespace mgb::imperative::python { | |||
struct GradKey; | |||
extern interpreter::Interpreter::Channel* interpreter_for_py; | |||
extern PyTypeObject* py_tensor_type; | |||
class SharedHandle { | |||
using Handle = interpreter::Interpreter::Handle; | |||
static_assert(std::is_pointer_v<Handle>); | |||
std::shared_ptr<std::remove_pointer_t<Handle>> holder; | |||
public: | |||
inline explicit SharedHandle(Handle handle) | |||
: holder(handle, [](auto* h) { | |||
if (h) { | |||
interpreter_for_py->del(h); | |||
} | |||
}) {} | |||
SharedHandle(const SharedHandle&) = default; | |||
SharedHandle& operator=(const SharedHandle&) = default; | |||
SharedHandle(SharedHandle&&) = default; | |||
SharedHandle& operator=(SharedHandle&&) = default; | |||
inline Handle get() { return holder.get(); } | |||
}; | |||
// impl in grad.cpp | |||
class GradInfoCollection { | |||
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
private: | |||
SmallVector<GradInfo> m_storage; | |||
protected: | |||
void _shrink(); | |||
std::string m_name; | |||
ValueRef m_data; | |||
public: | |||
bool contains(GradKey* key); | |||
GradInfo& operator[](GradKey* key); | |||
GradInfo& at(GradKey* key); | |||
bool empty() { | |||
_shrink(); | |||
return m_storage.empty(); | |||
} | |||
auto begin() { | |||
_shrink(); | |||
return m_storage.begin(); | |||
} | |||
auto end() { | |||
_shrink(); | |||
return m_storage.end(); | |||
} | |||
size_t count(GradKey* key) { return contains(key) ? 1 : 0; } | |||
}; | |||
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
using flags_t = uint64_t; | |||
struct Flags { | |||
static constexpr flags_t SCALAR = 1; | |||
static constexpr flags_t GRAD = 1 << 1; | |||
static constexpr flags_t TRACE = 1 << 2; | |||
static constexpr flags_t MODULE_TRACE = 1 << 3; | |||
}; | |||
flags_t m_flags = 0; | |||
GradInfoCollection m_grad_info_dict; | |||
TraceInfo m_trace_info; | |||
SharedHandle m_handle; | |||
std::string user_custom_name; | |||
std::string automatic_name; | |||
cg::VarNode* m_var; | |||
pybind11::object m_module_trace_info; | |||
using Handle = interpreter::Interpreter::Handle; | |||
inline Tensor() : m_handle(nullptr), m_var(nullptr) {} | |||
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} | |||
inline explicit Tensor(SharedHandle handle) | |||
: m_handle(std::move(handle)), m_var(nullptr) {} | |||
inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {} | |||
inline explicit Tensor(ValueRef data) : m_data{data} {} | |||
~Tensor() = default; | |||
inline std::shared_ptr<Tensor> copy() { | |||
auto ret = std::make_shared<Tensor>(m_handle); | |||
ret->m_flags = m_flags; | |||
ret->m_grad_info_dict = m_grad_info_dict; | |||
ret->m_trace_info = m_trace_info; | |||
ret->m_var = m_var; | |||
auto ret = std::make_shared<Tensor>(m_data.unwrap()); | |||
ret->m_name = m_name; | |||
return ret; | |||
} | |||
inline DType dtype() { | |||
if (m_var) { | |||
return m_var->dtype(); | |||
inline DType dtype() { return *data().dtype(); } | |||
inline CompNode comp_node() { return *data().device(); } | |||
inline std::optional<ValueShape> shape() { | |||
auto shape = data().shape(); | |||
if (!shape) { | |||
return {}; | |||
} | |||
return interpreter_for_py->get_dtype(m_handle.get()); | |||
return *shape; | |||
} | |||
inline CompNode comp_node() { | |||
if (m_var) { | |||
return m_var->comp_node(); | |||
inline HostValue::ref_t numpy() { return data().numpy(); } | |||
inline void reset(ValueRef value) { | |||
m_data = value; | |||
if (!m_name.empty()) { | |||
set_name(m_name); | |||
} | |||
return interpreter_for_py->get_device(m_handle.get()); | |||
} | |||
inline TensorShape shape() { | |||
if (m_var) { | |||
return m_var->shape(); | |||
inline ValueRef data() { return m_data.unwrap(); } | |||
bool is_scalar() { return data().is_scalar(); } | |||
inline std::string name() { return m_name; } | |||
inline void set_name(std::string name) { | |||
m_name = name; | |||
if (!name.empty()) { | |||
auto output = imperative::apply(RenameValue(name), m_data)[0]; | |||
m_data = output; | |||
} | |||
return interpreter_for_py->get_shape(m_handle.get()); | |||
} | |||
}; | |||
struct TensorWrapper { | |||
public: | |||
std::shared_ptr<Tensor> m_tensor; | |||
inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) | |||
: m_tensor(std::move(tensor)) {} | |||
: m_tensor(std::move(tensor)) { | |||
mgb_assert(tensor, "empty storage"); | |||
} | |||
inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {} | |||
TensorWrapper(PyObject* args, PyObject* kwargs); | |||
~TensorWrapper() = default; | |||
@@ -191,33 +134,17 @@ struct TensorWrapper { | |||
void reset(PyObject*); | |||
PyObject* detach(); | |||
PyObject* isscalar(); | |||
void setscalar(); | |||
void unsetscalar(); | |||
PyObject* _dev_tensor(); | |||
void _drop(); | |||
PyObject* varnode(); | |||
void reset_varnode(); | |||
PyObject* handle(); | |||
void set_handle(PyObject*); | |||
PyObject* mixin_handle(); | |||
PyObject* recording(); | |||
PyObject* copied(); | |||
void set_mixin_handle(PyObject*); | |||
void set_recording(PyObject*); | |||
PyObject* compiled_info(); | |||
void set_compiled_info(PyObject*); | |||
PyObject* trace_mixin_info(); | |||
void set_trace_mixin_info(PyObject*); | |||
PyObject* module_trace_info(); | |||
void set_module_trace_info(PyObject*); | |||
PyObject* user_custom_name(); | |||
void set_user_custom_name(PyObject*); | |||
PyObject* automatic_name(); | |||
void set_automatic_name(PyObject*); | |||
void _set_name(PyObject*); | |||
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); } | |||
PyObject* _detail(); | |||
void _watch(); | |||
}; | |||
struct PySymbolVar { | |||
@@ -230,113 +157,8 @@ struct PySymbolVar { | |||
PyObject* py_apply( | |||
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); | |||
struct ApplyContext { | |||
static Tensor::flags_t global_disable; | |||
static Tensor::flags_t global_enable; | |||
Tensor::flags_t flags = 0; | |||
std::shared_ptr<OpDef> op; | |||
Tensor* const* args; | |||
size_t nargs; | |||
PyTypeObject* pytype = nullptr; | |||
bool backward = false; | |||
class scoped_disable : NonCopyableObj { | |||
Tensor::flags_t saved_flags; | |||
public: | |||
scoped_disable(Tensor::flags_t flags) | |||
: saved_flags(ApplyContext::global_disable) { | |||
ApplyContext::global_disable |= flags; | |||
} | |||
~scoped_disable() { ApplyContext::global_disable = saved_flags; } | |||
}; | |||
}; | |||
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||
apply_result_t apply(ApplyContext& ctx); | |||
template <typename T> | |||
decltype(auto) resolve_arrow(T&& p) { | |||
if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) { | |||
auto* ret = p; | |||
return ret; | |||
} else { | |||
auto probe = [](auto&& p) -> decltype(p.operator->()) {}; | |||
if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { | |||
return resolve_arrow(p.operator->()); | |||
} else { | |||
return std::forward<T>(p); | |||
} | |||
} | |||
} | |||
template <typename... Args> | |||
constexpr bool is_all_tensor_ptr = | |||
(... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); | |||
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> | |||
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | |||
ApplyContext ctx; | |||
Tensor* arg_arr[] = {resolve_arrow(args)...}; | |||
ctx.flags = (0 | ... | args->m_flags); | |||
ctx.args = arg_arr; | |||
ctx.nargs = sizeof...(args); | |||
ctx.op = std::move(op); | |||
return apply(ctx); | |||
} | |||
inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) { | |||
ApplyContext ctx; | |||
ctx.op = std::move(op); | |||
ctx.nargs = nargs; | |||
ctx.args = args; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
ctx.flags |= args[i]->m_flags; | |||
} | |||
return apply(ctx); | |||
} | |||
template <typename T> | |||
auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t< | |||
std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> { | |||
size_t nargs = tensors.size(); | |||
Tensor* args[nargs]; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
args[i] = resolve_arrow(tensors[i]); | |||
} | |||
return apply(op, args, nargs); | |||
} | |||
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value); | |||
inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) { | |||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
inputs.push_back(args[i]->shared_from_this()); | |||
} | |||
auto apply_functor = [](std::shared_ptr<OpDef> op, | |||
SmallVector<std::shared_ptr<Tensor>> inputs, | |||
size_t) { return apply(op, std::move(inputs)); }; | |||
return graph.apply(inputs, apply_functor, &make_const); | |||
} | |||
template <typename T> | |||
auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t< | |||
std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> { | |||
size_t nargs = tensors.size(); | |||
Tensor* args[nargs]; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
args[i] = resolve_arrow(tensors[i]); | |||
} | |||
return apply(graph, args, nargs); | |||
} | |||
void init_tensor(pybind11::module); | |||
extern PyObject* cpp_apply_with_tracing; | |||
extern PyObject* cpp_apply_backward_varnode; | |||
extern PyObject* cpp_apply_module_trace; | |||
} // namespace mgb::imperative::python | |||
@@ -1,63 +0,0 @@ | |||
/** | |||
* \file imperative/python/src/trace.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./trace.h" | |||
#include "./helper.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace py = pybind11; | |||
namespace mgb::imperative::python { | |||
apply_result_t apply_trace(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
if (ctx.backward) { | |||
// reach here when compiled=True | |||
auto args = py::tuple(ctx.nargs + 1); | |||
args[0] = py::cast(ctx.op); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
args[i + 1] = py::cast(ctx.args[i]->m_var); | |||
} | |||
py::object pyout = py::reinterpret_steal<py::object>( | |||
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); | |||
if (!pyout) | |||
throw py::error_already_set(); | |||
// assumption: python function always returns PyList | |||
auto tup = py::reinterpret_borrow<py::list>(pyout); | |||
for (size_t i = 0; i < tup.size(); i++) { | |||
auto pitem = tup[i].cast<cg::VarNode*>(); | |||
outputs.emplace_back(std::make_shared<Tensor>(pitem)); | |||
} | |||
return outputs; | |||
} | |||
auto args = py::tuple(ctx.nargs + 1); | |||
args[0] = py::cast(ctx.op); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||
} | |||
auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr); | |||
if (!pyout) | |||
throw py::error_already_set(); | |||
// assumption: python function always returns PyList | |||
auto tup = py::reinterpret_steal<py::list>(pyout); | |||
for (size_t i = 0; i < tup.size(); i++) { | |||
auto tw = TensorWrapper::try_cast(tup[i].ptr()); | |||
outputs.emplace_back(tw->m_tensor); | |||
} | |||
return outputs; | |||
} | |||
} // namespace mgb::imperative::python |
@@ -1,28 +0,0 @@ | |||
/** | |||
* \file imperative/python/src/trace.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <stdexcept> | |||
#include "./tensor.h" | |||
namespace mgb::imperative::python { | |||
class TraceReadError : public std::exception { | |||
public: | |||
explicit TraceReadError(const char* m) : message{m} {} | |||
const char* what() const noexcept override { return message.c_str(); } | |||
private: | |||
std::string message = ""; | |||
}; | |||
apply_result_t apply_trace(ApplyContext& ctx); | |||
} // namespace mgb::imperative::python |
@@ -1,49 +0,0 @@ | |||
/** | |||
* \file imperative/python/src/trace_info.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "Python.h" | |||
#include "inttypes.h" | |||
namespace mgb::imperative::python { | |||
struct TraceInfo { | |||
int64_t mixin_handle = -1; | |||
bool recording = false; | |||
// refer to CompiledTensorProxy in tracing.py; works from the second trace step onward | |||
PyObject* compiled_info = nullptr; | |||
// refer to TensorInfo in tracing.py; only used during the first trace step | |||
PyObject* trace_mixin_info = nullptr; | |||
TraceInfo() = default; | |||
TraceInfo& operator=(const TraceInfo& that) { | |||
mixin_handle = that.mixin_handle; | |||
recording = that.recording; | |||
trace_mixin_info = that.trace_mixin_info; | |||
Py_XINCREF(trace_mixin_info); | |||
compiled_info = that.compiled_info; | |||
Py_XINCREF(compiled_info); | |||
return *this; | |||
} | |||
~TraceInfo() { | |||
Py_XDECREF(trace_mixin_info); | |||
Py_XDECREF(compiled_info); | |||
} | |||
private: | |||
TraceInfo(const TraceInfo& that) = default; | |||
}; | |||
} // namespace mgb::imperative::python |
@@ -16,10 +16,6 @@ import megengine.module | |||
from megengine import Parameter | |||
from megengine.core._imperative_rt.core2 import sync | |||
from megengine.device import get_device_count | |||
from megengine.experimental.autograd import ( | |||
disable_higher_order_directive, | |||
enable_higher_order_directive, | |||
) | |||
from megengine.jit import trace as _trace | |||
from megengine.module import Linear, Module | |||
@@ -45,13 +41,3 @@ def skip_distributed(request): | |||
platform.system() | |||
) | |||
) | |||
@pytest.fixture(autouse=True) | |||
def resolve_require_higher_order_directive(request): | |||
marker = request.node.get_closest_marker("require_higher_order_directive") | |||
if marker: | |||
enable_higher_order_directive() | |||
yield | |||
if marker: | |||
disable_higher_order_directive() |
@@ -146,5 +146,5 @@ def test_dump_bn_train_mode(): | |||
data = mge.tensor(np.random.random((10, 10, 10, 10))) | |||
bn_train(data) | |||
with pytest.raises(AssertionError): | |||
with pytest.raises(RuntimeError): | |||
bn_train.dump("test.mge") |
@@ -17,7 +17,7 @@ import megengine.distributed as dist | |||
import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.optimizer as optim | |||
from megengine.autodiff import GradManager | |||
from megengine.autodiff import Function, GradManager | |||
from megengine.jit import trace | |||
@@ -214,7 +214,7 @@ def test_remote_grad(trace_mode): | |||
x = dist.functional.remote_recv(rank - 1) | |||
y = m(x) | |||
if rank != size - 1: | |||
dist.functional.remote_send(y, dest_rank=rank + 1) | |||
x = dist.functional.remote_send(y, dest_rank=rank + 1) | |||
gm.backward() | |||
else: | |||
y = y.mean() | |||
@@ -224,7 +224,7 @@ def test_remote_grad(trace_mode): | |||
if trace_mode is not None: | |||
train_func = trace(symbolic=trace_mode)(train_func) | |||
for i in range(3): | |||
for i in range(1): | |||
train_func(x) | |||
worker() | |||
@@ -340,7 +340,6 @@ def test_broadcast_grad(trace_mode): | |||
worker() | |||
@pytest.mark.require_higher_order_directive() | |||
def test_2nd_grad_with_manager(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = mge.tensor(x_np) | |||
@@ -359,7 +358,6 @@ def test_2nd_grad_with_manager(): | |||
) | |||
@pytest.mark.require_higher_order_directive() | |||
def test_grad_manager_group(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = mge.tensor(x_np) | |||
@@ -376,7 +374,6 @@ def test_grad_manager_group(): | |||
x.grad = None | |||
@pytest.mark.require_higher_order_directive() | |||
def test_grad_manager_group_visibility(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = mge.tensor(x_np) | |||
@@ -392,7 +389,6 @@ def test_grad_manager_group_visibility(): | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
@pytest.mark.require_higher_order_directive() | |||
def test_grad_manager_visibility_by_order(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = mge.tensor(x_np) | |||
@@ -410,7 +406,6 @@ def test_grad_manager_visibility_by_order(): | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
@pytest.mark.require_higher_order_directive() | |||
@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) | |||
def test_emulate_forward_mode_with_reverse_mode(target): | |||
def jvp(inp, expr): | |||
@@ -434,3 +429,43 @@ def test_emulate_forward_mode_with_reverse_mode(target): | |||
np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) | |||
np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) | |||
def test_2nd_grad_with_custom_gradient():
    class MySin(Function):
        def forward(self, x):
            self.inp = x
            x = mge.Tensor(x.numpy())
            y = F.sin(x)
            return y

        def backward(self, dy):
            dx = F.cos(self.inp) * dy
            return dx

    class MyCos(Function):
        def forward(self, x):
            self.inp = x
            x = mge.Tensor(x.numpy())
            y = F.cos(x)
            return y

        def backward(self, dy):
            dx = -MySin()(self.inp) * dy
            return dx

    x_np = np.random.rand(10).astype("float32")
    x = mge.tensor(x_np)
    gm = GradManager().attach([x])
    gm2 = GradManager().attach([x])

    with gm:
        with gm2:
            y = MyCos()(x)
            gm2.backward(y)
        np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
        gm.backward(x.grad)
    np.testing.assert_almost_equal(
        x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
    )
@@ -7,8 +7,6 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import gc | |||
import platform | |||
import weakref | |||
import numpy as np | |||
import pytest | |||
@@ -60,24 +58,20 @@ def test_dist_grad(): | |||
def worker(): | |||
rank = dist.get_rank() | |||
if rank == 0: | |||
grad = Grad() | |||
x = as_tensor(x_np) | |||
grad.wrt(x, callback=save_to(x)) | |||
# need a placeholder to trace operator | |||
remote_send(x, 1) | |||
recv_x = remote_recv(1) | |||
y = recv_x * recv_x | |||
grad([y], [as_tensor(np.ones_like(x_np))]) | |||
with Grad() as grad: | |||
x = as_tensor(x_np) | |||
grad.wrt(x, callback=save_to(x)) | |||
# need a placeholder to trace operator | |||
remote_send(x, 1) | |||
recv_x = remote_recv(1) | |||
y = recv_x * recv_x | |||
grad([y], [as_tensor(np.ones_like(x_np))]) | |||
np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2) | |||
elif rank == 1: | |||
grad = Grad() | |||
recv_x = remote_recv(0) | |||
remote_send(recv_x, 0) | |||
grad([], []) | |||
with Grad() as grad: | |||
recv_x = remote_recv(0) | |||
remote_send(recv_x, 0) | |||
grad([], []) | |||
worker() | |||
@@ -86,11 +80,11 @@ def test_grad(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = as_tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
y = cos(x) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = cos(x) | |||
grad(y, as_tensor(np.ones_like(x_np))) | |||
grad(y, as_tensor(np.ones_like(x_np))) | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np)) | |||
@@ -98,12 +92,12 @@ def test_grad_2(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = as_tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = mul(x, x) | |||
y = mul(y, y) | |||
grad(y, as_tensor(np.ones_like(x_np))) | |||
y = mul(x, x) | |||
y = mul(y, y) | |||
grad(y, as_tensor(np.ones_like(x_np))) | |||
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
@@ -113,32 +107,31 @@ def test_2nd_grad(): | |||
x = as_tensor(x_np) | |||
ones = as_tensor(np.ones_like(x_np)) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
grad._priority = -1 | |||
grad2 = Grad().wrt(x, callback=save_to(x)) | |||
grad2._priority = 0 | |||
y = cos(x) | |||
with Grad("grad2") as grad2: | |||
with Grad("grad") as grad: | |||
grad2.wrt(x, callback=save_to(x)) | |||
grad.wrt(x, callback=save_to(x)) | |||
y = cos(x) | |||
grad(y, ones) | |||
z = x.grad | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
grad(y, ones) | |||
z = x.grad | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
x.grad = None | |||
grad2(z, ones) | |||
x.grad = None | |||
grad2(z, ones) | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||
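The expected values in `test_2nd_grad` follow from d/dx cos(x) = -sin(x) and d²/dx² cos(x) = -cos(x); a pure-numpy finite-difference check of the same two quantities, independent of any autodiff machinery:

    import numpy as np
    x = np.random.rand(10)          # float64 keeps the difference quotients clean
    eps = 1e-4
    d1 = (np.cos(x + eps) - np.cos(x - eps)) / (2 * eps)                # ~ -sin(x)
    d2 = (np.cos(x + eps) - 2 * np.cos(x) + np.cos(x - eps)) / eps**2   # ~ -cos(x)
    np.testing.assert_allclose(d1, -np.sin(x), atol=1e-7)
    np.testing.assert_allclose(d2, -np.cos(x), atol=1e-5)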
def test_grad_with_tensor_wrapper(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = mul(x, x) | |||
y = mul(y, y) | |||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||
y = mul(x, x) | |||
y = mul(y, y) | |||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
@@ -162,18 +155,21 @@ def test_release(): | |||
@check | |||
def _(): | |||
g = Grad().wrt(x) | |||
y = x * x | |||
g(y, dy) | |||
with Grad() as g: | |||
g.wrt(x) | |||
y = x * x | |||
g(y, dy) | |||
@check | |||
def _(): | |||
with Grad().wrt(x): | |||
with Grad() as g: | |||
g.wrt(x) | |||
pass | |||
@check | |||
def _(): | |||
with Grad().wrt(x): | |||
with Grad() as g: | |||
g.wrt(x) | |||
y = x * x | |||
@@ -181,12 +177,12 @@ def test_grad_inplace(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = mul(x, x) | |||
y *= y | |||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||
y = mul(x, x) | |||
y *= y | |||
grad(y, mge.Tensor(np.ones_like(x_np))) | |||
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
@@ -196,11 +192,11 @@ def test_identity(): | |||
dy_np = np.random.rand(*x.shape).astype("float32") | |||
dy = mge.Tensor(dy_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
(y,) = apply(Identity(), x) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
(y,) = apply(Identity(), x) | |||
grad(y, dy) | |||
grad(y, dy) | |||
np.testing.assert_array_equal(x.grad.numpy(), dy_np) | |||
@@ -220,15 +216,14 @@ def test_elemwise_add(): | |||
refs["y"] = TensorWeakRef(y) | |||
return x + y | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
z = f(x, y) | |||
del y | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
z = f(x, y) | |||
del y | |||
for k, r in refs.items(): | |||
assert r() is None | |||
grad(z, dz) | |||
for k, r in refs.items(): | |||
assert r() is None | |||
grad(z, dz) | |||
np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5) | |||
@@ -245,13 +240,12 @@ def test_elemwise_relu(): | |||
refs["x"] = TensorWeakRef(x) | |||
return relu(x) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
z = f(x) | |||
assert refs["x"]() is None | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
z = f(x) | |||
assert refs["x"]() is None | |||
grad(z, dz) | |||
grad(z, dz) | |||
np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0]) | |||
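The weak-reference assertions in the two tests above check that recording does not retain forward temporaries it will never need: for `x + y` the input gradient is just the incoming gradient (reduced over broadcast axes), and for `relu` the backward mask is recoverable from the output alone. A numpy sketch of the relu case under that standard formulation:

    import numpy as np
    x = np.array([1.0, -1.0], dtype=np.float32)
    y = np.maximum(x, 0)                 # forward; x need not be kept afterwards
    dz = np.array([2.0, 3.0], dtype=np.float32)
    dx = dz * (y > 0)                    # backward mask derived from the output
    np.testing.assert_almost_equal(dx, [2.0, 0.0])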
@@ -269,21 +263,21 @@ def test_reshape(): | |||
x_np = np.random.rand(2, 5).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
refs = {} | |||
refs = {} | |||
def f(x): | |||
x = x * 1 | |||
y = x.reshape(5, 2) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
def f(x): | |||
x = x * 1 | |||
y = x.reshape(5, 2) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) | |||
@@ -291,21 +285,21 @@ def test_subtensor(): | |||
x_np = np.random.rand(3, 3).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
refs = {} | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
refs = {} | |||
def f(x): | |||
x = x * 1 | |||
y = x[1:-1, :2] | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
def f(x): | |||
x = x * 1 | |||
y = x[1:-1, :2] | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal( | |||
np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy() | |||
) | |||
@@ -315,21 +309,21 @@ def test_IndexingMultiAxisVec(): | |||
x_np = np.random.rand(3, 3).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
refs = {} | |||
refs = {} | |||
def f(x): | |||
x = x * 1 | |||
y = x[[0, 2], [0, 2]] | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
def f(x): | |||
x = x * 1 | |||
y = x[[0, 2], [0, 2]] | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal( | |||
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() | |||
) | |||
@@ -339,21 +333,21 @@ def test_AxisAddRemove(): | |||
x_np = np.random.rand(1, 5).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
refs = {} | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
refs = {} | |||
def f(x): | |||
x = x * 1 | |||
y = F.squeeze(F.expand_dims(x, 2), 0) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
def f(x): | |||
x = x * 1 | |||
y = F.squeeze(F.expand_dims(x, 2), 0) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal( | |||
np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy() | |||
) | |||
@@ -363,10 +357,11 @@ def test_Broadcast(): | |||
x_np = np.random.rand(3, 3, 1).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
y = F.broadcast_to(x, (3, 3, 10)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = F.broadcast_to(x, (3, 3, 10)) | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) | |||
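The expected `ones * 10` follows because the vector-Jacobian product of `broadcast_to` is a sum over the broadcast axes; a numpy sketch of that adjoint:

    import numpy as np
    dy = np.ones((3, 3, 10), dtype=np.float32)
    dx = dy.sum(axis=2, keepdims=True)        # adjoint of broadcasting axis 2
    np.testing.assert_equal(dx, np.full((3, 3, 1), 10, dtype=np.float32))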
@@ -374,10 +369,11 @@ def test_interpolate_fastpath(): | |||
x_np = np.random.rand(3, 3, 32, 32).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | |||
@@ -385,10 +381,11 @@ def test_Reduce_sum(): | |||
x_np = np.random.rand(3, 3).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
y = x.sum(axis=0) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = x.sum(axis=0) | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | |||
@@ -396,10 +393,11 @@ def test_Reduce_mean(): | |||
x_np = np.random.rand(3, 3).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
y = x.mean(axis=0) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
y = x.mean(axis=0) | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy()) | |||
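Conversely, the adjoint of a reduction broadcasts the incoming gradient, scaled by 1/n for `mean` over an axis of length n, which is where the `ones / 3` above comes from:

    import numpy as np
    dy = np.ones((3,), dtype=np.float32)
    dx = np.broadcast_to(dy / 3, (3, 3))      # mean over axis 0 of length n = 3
    np.testing.assert_equal(dx, np.full((3, 3), 1 / 3, dtype=np.float32))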
@@ -407,21 +405,21 @@ def test_addAxis(): | |||
x_np = np.random.rand(3, 3).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
refs = {} | |||
refs = {} | |||
def f(x): | |||
x = x * 1 | |||
y = F.expand_dims(x, [2, 3]) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
def f(x): | |||
x = x * 1 | |||
y = F.expand_dims(x, [2, 3]) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | |||
@@ -429,21 +427,21 @@ def test_removeAxis(): | |||
x_np = np.random.rand(3, 3, 1, 1).astype("float32") | |||
x = mge.Tensor(x_np) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
refs = {} | |||
refs = {} | |||
def f(x): | |||
x = x * 1 | |||
y = F.squeeze(x, [2, 3]) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
def f(x): | |||
x = x * 1 | |||
y = F.squeeze(x, [2, 3]) | |||
refs["x"] = TensorWeakRef(x) | |||
return y | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
y = f(x) | |||
for _, r in refs.items(): | |||
assert r() is None | |||
grad(y, F.ones_like(y)) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) | |||
@@ -452,11 +450,14 @@ def test_dot(): | |||
x = mge.Tensor(x) | |||
u = F.ones((2,)) | |||
v = F.ones((2,)) | |||
grad = Grad().wrt(x, callback=save_to(x)) | |||
def f(x): | |||
return F.dot(u, F.matmul(x, v)) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
def f(x): | |||
return F.dot(u, F.matmul(x, v)) | |||
y = f(x) | |||
grad(y, F.ones_like(y)) | |||
y = f(x) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) |
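The expected gradient in `test_dot` is the standard identity d/dX (uᵀ X v) = u vᵀ, which is all-ones here because u = v = ones(2); a one-line numpy confirmation:

    import numpy as np
    u = np.ones(2, dtype=np.float32)
    v = np.ones(2, dtype=np.float32)
    np.testing.assert_equal(np.outer(u, v), np.ones((2, 2), dtype=np.float32))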
@@ -267,25 +267,27 @@ def _gen_roi_inp(): | |||
def test_roi_align(): | |||
inp_feat, rois = _gen_roi_inp() | |||
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||
output_shape = (7, 7) | |||
out_feat = F.vision.roi_align( | |||
inp_feat, | |||
rois, | |||
output_shape=output_shape, | |||
mode="average", | |||
spatial_scale=1.0 / 4, | |||
sample_points=2, | |||
aligned=True, | |||
) | |||
assert make_shape_tuple(out_feat.shape) == ( | |||
rois.shape[0], | |||
inp_feat.shape[1], | |||
*output_shape, | |||
) | |||
with Grad() as grad: | |||
grad.wrt(inp_feat, callback=_save_to(inp_feat)) | |||
output_shape = (7, 7) | |||
out_feat = F.vision.roi_align( | |||
inp_feat, | |||
rois, | |||
output_shape=output_shape, | |||
mode="average", | |||
spatial_scale=1.0 / 4, | |||
sample_points=2, | |||
aligned=True, | |||
) | |||
assert make_shape_tuple(out_feat.shape) == ( | |||
rois.shape[0], | |||
inp_feat.shape[1], | |||
*output_shape, | |||
) | |||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
@@ -307,20 +309,23 @@ def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): | |||
def test_correlation(): | |||
## test case 0: check the grad shape | |||
data1, data2 = _gen_correlation() | |||
grad = Grad().wrt(data1, callback=_save_to(data1)) | |||
out_feat = F.vision.correlation( | |||
data1, | |||
data2, | |||
kernel_size=5, | |||
max_displacement=4, | |||
stride1=2, | |||
stride2=2, | |||
pad_size=2, | |||
is_multiply=True, | |||
) | |||
with Grad() as grad: | |||
grad.wrt(data1, callback=_save_to(data1)) | |||
out_feat = F.vision.correlation( | |||
data1, | |||
data2, | |||
kernel_size=5, | |||
max_displacement=4, | |||
stride1=2, | |||
stride2=2, | |||
pad_size=2, | |||
is_multiply=True, | |||
) | |||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||
assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) | |||
## test case 1: from https://github.com/NVIDIA/flownet2-pytorch/issues/194 | |||
@@ -391,32 +396,36 @@ def test_correlation(): | |||
def test_roi_pooling(): | |||
inp_feat, rois = _gen_roi_inp() | |||
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||
output_shape = (7, 7) | |||
out_feat = F.vision.roi_pooling( | |||
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, | |||
) | |||
assert make_shape_tuple(out_feat.shape) == ( | |||
rois.shape[0], | |||
inp_feat.shape[1], | |||
*output_shape, | |||
) | |||
with Grad() as grad: | |||
grad.wrt(inp_feat, callback=_save_to(inp_feat)) | |||
output_shape = (7, 7) | |||
out_feat = F.vision.roi_pooling( | |||
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, | |||
) | |||
assert make_shape_tuple(out_feat.shape) == ( | |||
rois.shape[0], | |||
inp_feat.shape[1], | |||
*output_shape, | |||
) | |||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||
grad(out_feat, tensor(F.ones_like(out_feat))) | |||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
def test_adaptive_avg_pool2d(): | |||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||
oshp = (2, 2) | |||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||
outp = F.adaptive_avg_pool2d(inp, oshp,) | |||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
np.testing.assert_equal( | |||
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||
) | |||
with Grad() as grad: | |||
grad.wrt(inp, callback=_save_to(inp)) | |||
outp = F.adaptive_avg_pool2d(inp, oshp,) | |||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
np.testing.assert_equal( | |||
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||
) | |||
grad(outp, tensor(F.ones_like(outp))) | |||
grad(outp, tensor(F.ones_like(outp))) | |||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||
np.testing.assert_equal( | |||
inp.grad.numpy(), | |||
@@ -439,14 +448,16 @@ def test_adaptive_avg_pool2d(): | |||
def test_adaptive_max_pool2d(): | |||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||
oshp = (2, 2) | |||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||
outp = F.adaptive_max_pool2d(inp, oshp,) | |||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
np.testing.assert_equal( | |||
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||
) | |||
with Grad() as grad: | |||
grad.wrt(inp, callback=_save_to(inp)) | |||
outp = F.adaptive_max_pool2d(inp, oshp,) | |||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
np.testing.assert_equal( | |||
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||
) | |||
grad(outp, tensor(F.ones_like(outp))) | |||
grad(outp, tensor(F.ones_like(outp))) | |||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||
np.testing.assert_equal( | |||
inp.grad.numpy(), | |||
@@ -351,7 +351,7 @@ def test_expand_dims_for_scalar(): | |||
for axis in [1, -2, (1, 2), (-2, -3)]: | |||
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) | |||
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis) | |||
np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
@@ -9,6 +9,7 @@ | |||
import inspect | |||
import io | |||
import itertools | |||
import random | |||
from tempfile import mkstemp | |||
import numpy as np | |||
@@ -25,7 +26,7 @@ from megengine.core.ops import builtin as ops | |||
from megengine.core.ops.builtin import Elemwise | |||
from megengine.core.tensor.utils import isscalar | |||
from megengine.functional import exp, log | |||
from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace | |||
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace | |||
from megengine.module import Module | |||
from megengine.random import normal, uniform | |||
from megengine.utils.naming import AutoNaming | |||
@@ -464,36 +465,92 @@ def test_trace_warp_perspective(): | |||
f(x, M) | |||
def test_raise_on_trace(): | |||
step_count = 0 | |||
catch_count = 0 | |||
bad_step = 10 | |||
@pytest.mark.parametrize( | |||
"normal_expr, mismatch_expr, reason", | |||
[ | |||
("a + b + c", "a + b - c", "operator mismatch"), | |||
("a + b + 1", "a + b + 2", "tensors not equals"), | |||
("((a + b), (b + c))[0]", "a + b", "mismature end"), | |||
("a + b + c", "c + (a + b)", "expect internal node, got external"), | |||
("c + (a + b)", "a + b + c", "expect external node, got internal"), | |||
("a + b + c", "a + b + c + c", "too many instructions"), | |||
("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"), | |||
("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"), | |||
], | |||
) | |||
def test_trace_mismatch(normal_expr, mismatch_expr, reason): | |||
a = tensor([1, 2, 3, 4]) | |||
b = tensor([5, 6, 7, 8]) | |||
c = tensor([9, 0, 1, 2]) | |||
mismatch = False | |||
@trace(symbolic=True) | |||
def fn(a, b, c): | |||
if not mismatch: | |||
result = eval(normal_expr) | |||
else: | |||
result = eval(mismatch_expr) | |||
return result | |||
for i in range(20): | |||
try: | |||
d = fn(a, b, c) | |||
except TraceError as e: | |||
assert mismatch | |||
assert str(e) == "trace error because {}".format(reason) | |||
except: | |||
pytest.fail("unexpected trace error") | |||
else: | |||
assert not mismatch | |||
np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy()) | |||
mismatch = random.random() > 0.8 | |||
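The loop above warms the trace up with the normal expression and randomly swaps in the mismatching one, expecting `TraceError` with a message naming the failure. A stripped-down sketch of the same mechanism, using only the `TraceError` import added in this diff (which call raises is runtime-dependent, which is why the test loops and flips the flag randomly):

    import numpy as np
    from megengine import tensor
    from megengine.jit import TraceError, trace

    mismatch = False

    @trace(symbolic=True)
    def fn(a, b):
        return a - b if mismatch else a + b   # program changes when the flag flips

    a, b = tensor([1.0]), tensor([2.0])
    fn(a, b)                 # records a + b
    mismatch = True
    try:
        fn(a, b)             # replay meets a different operator
    except TraceError:
        print("trace mismatch detected")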
class CatchMe(Exception): | |||
pass | |||
def test_exception_in_trace(): | |||
a = tensor([1, 2, 3, 4]) | |||
b = tensor([5, 6, 7, 8]) | |||
c = tensor([9, 0, 1, 2]) | |||
@trace | |||
def add_abc(a, b, c): | |||
ps = a + b | |||
result = ps + c | |||
if step_count == bad_step: | |||
raise CatchMe("catch me") | |||
mismatch = False | |||
exc = Exception() | |||
@trace(symbolic=True) | |||
def fn(a, b, c): | |||
result = a + b | |||
if not mismatch: | |||
result += c | |||
else: | |||
raise exc | |||
return result | |||
for i in range(100): | |||
for i in range(20): | |||
try: | |||
d = add_abc(a, b, c) | |||
except CatchMe as e: | |||
catch_count += 1 | |||
d = fn(a, b, c) | |||
except TraceError as e: | |||
pytest.fail("unexpected trace error") | |||
except Exception as e: | |||
assert mismatch | |||
assert e is exc | |||
else: | |||
assert not mismatch | |||
np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) | |||
step_count += 1 | |||
mismatch = random.random() > 0.8 | |||
assert catch_count == 1 | |||
def test_graph_error(): | |||
a = tensor(np.arange(8).reshape((2, 4))) | |||
b = tensor(np.arange(8).reshape((2, 4))) | |||
@trace(symbolic=True) | |||
def fn(a, b): | |||
return a + b | |||
fn(a, b) | |||
with pytest.raises(RuntimeError): | |||
fn(a, b.transpose()) | |||
fn(a, b) | |||
@pytest.mark.parametrize("trace_mode", [False, True]) | |||
@@ -653,9 +710,10 @@ def test_trace_jit_config(): | |||
x = tensor(2) | |||
y = func(x) | |||
func._compile() | |||
y = func(x) | |||
# func._compile() | |||
options = func._graph.options | |||
options = func._trace.options | |||
mapping = {None: 0, False: 1, True: 2} | |||
assert options.graph_opt.jit == 0 | |||
assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle] | |||
@@ -82,9 +82,10 @@ def test_tqt(): | |||
x = mge.tensor(x, dtype="float32") | |||
s = mge.tensor(s, dtype="float32") | |||
g_y = mge.tensor(g_y, dtype="float32") | |||
grad = Grad().wrt(x, s, callback=cb) | |||
y = tqt_forward(-127, 127, x, s) | |||
grad(y, g_y) | |||
with Grad() as grad: | |||
grad.wrt(x, s, callback=cb) | |||
y = tqt_forward(-127, 127, x, s) | |||
grad(y, g_y) | |||
g_x, g_s = g | |||
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) | |||
@@ -131,14 +132,16 @@ def test_fakequant(): | |||
# test backward | |||
x = tensor(inp_data, dtype=np.float32) | |||
grad = Grad().wrt(x, callback=_save_to(x)) | |||
y = fake_quant_tensor(x, qparams) | |||
grad(y, tensor(F.ones_like(x))) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=_save_to(x)) | |||
y = fake_quant_tensor(x, qparams) | |||
grad(y, tensor(F.ones_like(x))) | |||
x1 = tensor(inp_data, dtype=np.float32) | |||
grad = Grad().wrt(x1, callback=_save_to(x1)) | |||
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) | |||
grad(y1, tensor(F.ones_like(x1))) | |||
with Grad() as grad: | |||
grad.wrt(x1, callback=_save_to(x1)) | |||
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) | |||
grad(y1, tensor(F.ones_like(x1))) | |||
assert np.allclose(x.grad.numpy(), x1.grad.numpy()) | |||
assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) | |||
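`fake_quant_tensor_gt` is presumably the standard fake-quantization reference; for orientation, a numpy sketch of that textbook formulation with a straight-through estimator for the input gradient (formula assumed, not taken from this diff):

    import numpy as np

    def fake_quant(x, scale, zero_point, qmin, qmax):
        q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
        return (q - zero_point) * scale       # quantize, clamp, dequantize

    def fake_quant_grad_x(x, dy, scale, zero_point, qmin, qmax):
        q = np.round(x / scale) + zero_point
        mask = (q >= qmin) & (q <= qmax)      # STE: identity inside the clamp range
        return dy * mask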
@@ -237,9 +240,10 @@ def test_lsq(): | |||
grad_s = mge.tensor(grad_s, dtype="float32") | |||
g_y = mge.tensor(g_y, dtype="float32") | |||
grad = Grad().wrt(x, s, callback=cb) | |||
y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||
grad(y, g_y) | |||
with Grad() as grad: | |||
grad.wrt(x, s, callback=cb) | |||
y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||
grad(y, g_y) | |||
g_x, g_s = g | |||
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) | |||
@@ -430,9 +430,10 @@ def test_ShuffleRNG(): | |||
n, m = 6, 3 | |||
arr = np.arange(n * m) | |||
out0 = Tensor(arr, dtype="float32") | |||
grad = Grad().wrt(out0, callback=cb) | |||
random.shuffle(out0) | |||
grad(out0, F.ones_like(out0)) | |||
with Grad() as grad: | |||
grad.wrt(out0, callback=cb) | |||
random.shuffle(out0) | |||
grad(out0, F.ones_like(out0)) | |||
m1 = RNG(seed=111, device="xpu0") | |||
m2 = RNG(seed=111, device="xpu1") | |||
m3 = RNG(seed=222, device="xpu0") | |||
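The shuffle gradient recorded above routes the incoming gradient back through the inverse permutation, assuming shuffle is implemented as a gather by a sampled permutation; a numpy sketch of that adjoint:

    import numpy as np
    rng = np.random.default_rng(0)
    x = np.arange(6.0)
    perm = rng.permutation(6)
    y = x[perm]                  # forward: gather by the permutation
    dy = np.arange(6.0)          # some incoming gradient
    dx = np.zeros_like(x)
    dx[perm] = dy                # backward: scatter through the inverse permutation
    np.testing.assert_equal(dx[perm], dy)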