Browse Source

refactor(dispatch): switch to new dispatch system

GitOrigin-RevId: 32dd49a23a
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
0bdd0b1467
46 changed files with 1622 additions and 3613 deletions
  1. +11
    -23
      imperative/python/megengine/autodiff/grad_manager.py
  2. +98
    -31
      imperative/python/megengine/core/autodiff/grad.py
  3. +1
    -13
      imperative/python/megengine/core/tensor/array_method.py
  4. +3
    -14
      imperative/python/megengine/core/tensor/indexing.py
  5. +2
    -5
      imperative/python/megengine/core/tensor/megbrain_graph.py
  6. +1
    -10
      imperative/python/megengine/core/tensor/utils.py
  7. +8
    -28
      imperative/python/megengine/distributed/functional.py
  8. +0
    -3
      imperative/python/megengine/distributed/helper.py
  9. +0
    -25
      imperative/python/megengine/experimental/autograd.py
  10. +0
    -3
      imperative/python/megengine/functional/inplace.py
  11. +1
    -3
      imperative/python/megengine/functional/math.py
  12. +0
    -1
      imperative/python/megengine/functional/nn.py
  13. +0
    -9
      imperative/python/megengine/functional/tensor.py
  14. +1
    -13
      imperative/python/megengine/jit/__init__.py
  15. +161
    -962
      imperative/python/megengine/jit/tracing.py
  16. +35
    -6
      imperative/python/megengine/module/module.py
  17. +5
    -0
      imperative/python/megengine/random/rng.py
  18. +20
    -7
      imperative/python/megengine/tensor.py
  19. +0
    -3
      imperative/python/megengine/traced_module/__init__.py
  20. +0
    -3
      imperative/python/megengine/traced_module/checker.py
  21. +11
    -2
      imperative/python/megengine/traced_module/expr.py
  22. +0
    -2
      imperative/python/megengine/traced_module/module_tracer.py
  23. +2
    -6
      imperative/python/megengine/traced_module/traced_module.py
  24. +5
    -0
      imperative/python/megengine/utils/profiler.py
  25. +59
    -625
      imperative/python/src/grad.cpp
  26. +15
    -141
      imperative/python/src/grad.h
  27. +0
    -43
      imperative/python/src/grad_info.h
  28. +237
    -164
      imperative/python/src/grad_override.cpp
  29. +0
    -245
      imperative/python/src/intrusive_list.h
  30. +0
    -42
      imperative/python/src/module_trace.cpp
  31. +41
    -1
      imperative/python/src/module_trace.h
  32. +2
    -7
      imperative/python/src/ops.cpp
  33. +493
    -541
      imperative/python/src/tensor.cpp
  34. +38
    -216
      imperative/python/src/tensor.h
  35. +0
    -63
      imperative/python/src/trace.cpp
  36. +0
    -28
      imperative/python/src/trace.h
  37. +0
    -49
      imperative/python/src/trace_info.h
  38. +0
    -14
      imperative/python/test/conftest.py
  39. +1
    -1
      imperative/python/test/integration/test_trace_dump.py
  40. +43
    -8
      imperative/python/test/unit/autodiff/test_grad_manger.py
  41. +162
    -161
      imperative/python/test/unit/core/test_autodiff.py
  42. +66
    -55
      imperative/python/test/unit/functional/test_functional.py
  43. +1
    -1
      imperative/python/test/unit/functional/test_tensor.py
  44. +79
    -21
      imperative/python/test/unit/jit/test_tracing.py
  45. +16
    -12
      imperative/python/test/unit/quantization/test_fake_quant.py
  46. +4
    -3
      imperative/python/test/unit/random/test_rng.py

+ 11
- 23
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -28,9 +28,6 @@ class AttachSpec:
__slots__ = "tensor", "callbacks" __slots__ = "tensor", "callbacks"




_global_priority = 0


class GradManager: class GradManager:
r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
automatic differentiation (a.k.a. back propagation). automatic differentiation (a.k.a. back propagation).
@@ -127,7 +124,6 @@ class GradManager:
self._grad = None self._grad = None
self._after_backward_callback = [] self._after_backward_callback = []
self._gradients = {} self._gradients = {}
self._priority = None


def attached_tensors(self): def attached_tensors(self):
r"""Return attached tensor list from :meth:`attach`.""" r"""Return attached tensor list from :meth:`attach`."""
@@ -299,31 +295,25 @@ class GradManager:
tensor.grad = grad tensor.grad = grad
else: else:
tensor.grad += grad tensor.grad += grad
if tensor._isscalar() and tensor.grad is not None:
tensor.grad._setscalar()
finally: finally:
self.release() self.release()
backwarding_grad_manager = cache backwarding_grad_manager = cache
set_option("record_computing_path", 1)
pop_scope("backward")
set_option("record_computing_path", 1)
pop_scope("backward")


def record(self): def record(self):
r"""Start recording operations r"""Start recording operations


After this call, you will be able to call :meth:`backward`. After this call, you will be able to call :meth:`backward`.
""" """
global _global_priority
if self._recording: if self._recording:
raise RuntimeError("already recording") raise RuntimeError("already recording")
grad = Grad() grad = Grad()
self._recording = True self._recording = True
self._grad = grad self._grad = grad
grad.__enter__()
for spec in self._attach_specs.values(): for spec in self._attach_specs.values():
self._do_record(spec) self._do_record(spec)
if self._priority is None:
grad._priority = _global_priority
_global_priority -= 1
grad.__enter__()


def _do_record(self, spec): def _do_record(self, spec):
tensor = spec.tensor() tensor = spec.tensor()
@@ -331,6 +321,8 @@ class GradManager:
return return


def callback(grad, callbacks=spec.callbacks): def callback(grad, callbacks=spec.callbacks):
from ..functional import ones_like

for cb in callbacks: for cb in callbacks:
grad = cb(tensor, grad) grad = cb(tensor, grad)
self._gradients[id(tensor)] = grad self._gradients[id(tensor)] = grad
@@ -343,14 +335,11 @@ class GradManager:


After this call, you will not be able to call :meth:`backward`. After this call, you will not be able to call :meth:`backward`.
""" """
global _global_priority
if self._grad is not None: if self._grad is not None:
self._grad.__exit__(None, None, None) self._grad.__exit__(None, None, None)
self._grad = None self._grad = None
self._recording = False self._recording = False
self._gradients = dict() self._gradients = dict()
if self._priority is None:
_global_priority += 1


def __enter__(self): def __enter__(self):
self.record() self.record()
@@ -382,15 +371,14 @@ class GradManagerGroup:
__ror__ = merge_with __ror__ = merge_with


def __enter__(self): def __enter__(self):
global _global_priority
_global_priority += 1
Grad.stack.append([])
Grad.begin_group()
for gm in self._gms: for gm in self._gms:
gm._priority = _global_priority
gm.record() gm.record()
assert gm._grad is not None
Grad.end_group()


def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
global _global_priority
_global_priority -= 1
for gm in self._gms:
for gm in reversed(self._gms):
gm.release() gm.release()
gm._priority = None
assert gm._grad is None

+ 98
- 31
imperative/python/megengine/core/autodiff/grad.py View File

@@ -6,17 +6,9 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref import weakref


import numpy as np

from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from .._imperative_rt import core2


_grad_count = 0 _grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary() _grad_manager_dict = weakref.WeakValueDictionary()
@@ -36,6 +28,10 @@ class GradKey(core2.GradKey):




class Grad: class Grad:
stack = []
grouping = False
key2grad = weakref.WeakValueDictionary()

def __init__(self, name=None): def __init__(self, name=None):
global _grad_count global _grad_count
if name is None: if name is None:
@@ -43,15 +39,9 @@ class Grad:
_grad_count += 1 _grad_count += 1
self._refkeeper = [] self._refkeeper = []
self._impl = GradKey(name) self._impl = GradKey(name)
Grad.key2grad[self._impl] = self
_grad_manager_dict[self._name] = self _grad_manager_dict[self._name] = self

@property
def _priority(self):
return self._impl.priority

@_priority.setter
def _priority(self, priority):
self._impl.priority = priority
self._group = [weakref.ref(self)]


@property @property
def _name(self): def _name(self):
@@ -70,33 +60,80 @@ class Grad:


if not isinstance(ys, Sequence): if not isinstance(ys, Sequence):
ys = [ys] ys = [ys]

if not isinstance(dys, Sequence): if not isinstance(dys, Sequence):
dys = [dys] dys = [dys]


group = [ref() for ref in self._group]

for grad in group:
if grad is self:
continue
grad.suppress()

self._impl.backward(ys, dys) self._impl.backward(ys, dys)


for grad in group:
if grad is self:
continue
grad.resume()

self._refkeeper = None self._refkeeper = None
return None


def __enter__(self): def __enter__(self):
ref = weakref.ref(self)
self._impl.enter()
if Grad.grouping:
group = Grad.stack[-1]
self._group = group
group.append(ref)
else:
Grad.stack.append(self._group)
return self return self


def __exit__(self, _1, _2, _3): def __exit__(self, _1, _2, _3):
self._impl.exit()
self._refkeeper = None self._refkeeper = None
del self._impl


class Function(ops.PyOpBase):
del Grad.key2grad[self._impl]
self._impl = None
self._group.remove(weakref.ref(self))
if len(self._group) == 0:
Grad.stack.remove(self._group)

@staticmethod
def begin_group():
assert not Grad.grouping
Grad.grouping = True

@staticmethod
def end_group():
group = Grad.stack[-1]
assert len(group) > 0
assert Grad.grouping
Grad.grouping = False

def suppress(self):
if self._impl is not None:
self._impl.suppress()

def resume(self):
if self._impl is not None:
self._impl.resume()


class Function:
r"""Defines a block of operations with customizable differentiation. r"""Defines a block of operations with customizable differentiation.
The computation should be defined in ``forward`` method, with gradient The computation should be defined in ``forward`` method, with gradient
computation defined in ``backward`` method. computation defined in ``backward`` method.
Each instance of ``Function`` should be used only once during forwardding. Each instance of ``Function`` should be used only once during forwardding.
Examples: Examples:
.. code-block:: .. code-block::
class Sigmoid(Function): class Sigmoid(Function):
def forward(self, x): def forward(self, x):
y = 1 / (1 + F.exp(-x)) y = 1 / (1 + F.exp(-x))
@@ -115,7 +152,7 @@ class Function(ops.PyOpBase):


Returns: Returns:
a tuple of Tensor or a single Tensor. a tuple of Tensor or a single Tensor.
Note: Note:
* This method should return a tuple of Tensor or a single Tensor representing the output * This method should return a tuple of Tensor or a single Tensor representing the output
of the function. of the function.
@@ -128,7 +165,7 @@ class Function(ops.PyOpBase):


Args: Args:
output_grads: gradients of outputs that are returned by :meth:`forward`. output_grads: gradients of outputs that are returned by :meth:`forward`.
Note: Note:
* In case when some tensors of outputs are not related to loss function, the corresponding * In case when some tensors of outputs are not related to loss function, the corresponding
values in ``output_grads`` would be ``None``. values in ``output_grads`` would be ``None``.
@@ -148,10 +185,40 @@ class Function(ops.PyOpBase):
return self._default_rule(*args), self.backward return self._default_rule(*args), self.backward


def __call__(self, *args): def __call__(self, *args):
ret = core2.apply(self, *args)
for arg in args:
if not isinstance(arg, core2.Tensor):
raise TypeError(
"op Function expect type Tensor as inputs, got {}".format(type(arg))
)

grad_key = core2.get_grad_key(args)
if grad_key is None:
return self._default_rule(*args)

grad = Grad.key2grad[grad_key]
group = [ref() for ref in grad._group]

for grad in group:
grad.suppress()
outputs, backward = self._grad_rule(*args)
for grad in reversed(group):
grad.resume()

def normalized_backward(*output_grads):
input_grads = backward(*output_grads)
if isinstance(input_grads, core2.Tensor) or input_grads is None:
input_grads = (input_grads,)
return input_grads

if self.__single_output: if self.__single_output:
(ret,) = ret
return ret
outputs = (outputs,)
for grad in reversed(group):
if grad._impl is None:
continue
outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
if self.__single_output:
(outputs,) = outputs
return outputs


def __getstate__(self): def __getstate__(self):
return self.__dict__ return self.__dict__


+ 1
- 13
imperative/python/megengine/core/tensor/array_method.py View File

@@ -26,7 +26,6 @@ from .utils import (
convert_inputs, convert_inputs,
isscalar, isscalar,
make_shape_tuple, make_shape_tuple,
setscalar,
) )


_ElwMod = builtin.Elemwise.Mode _ElwMod = builtin.Elemwise.Mode
@@ -34,14 +33,7 @@ _ElwMod = builtin.Elemwise.Mode


def _elwise_apply(args, mode): def _elwise_apply(args, mode):
op = builtin.Elemwise(mode) op = builtin.Elemwise(mode)
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
(result,) = apply(op, *args) (result,) = apply(op, *args)
if _isscalar:
setscalar(result)
return result return result




@@ -203,8 +195,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:


op = builtin.RemoveAxis(axis=axis) op = builtin.RemoveAxis(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
if len(axis) == inp.ndim:
setscalar(result)
return result return result




@@ -221,6 +211,7 @@ def _reduce(mode):


op = builtin.Reduce(mode=mode, axis=0) op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data) (result,) = apply(op, data)
result = _remove_axis(result, 0)
elif isinstance(axis, collections.abc.Iterable): elif isinstance(axis, collections.abc.Iterable):
axis = _normalize_axis(self.ndim, axis, reverse=True) axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis: for ai in axis:
@@ -239,8 +230,6 @@ def _reduce(mode):
if self.dtype == np.bool_: if self.dtype == np.bool_:
if mode in ["min", "max"]: if mode in ["min", "max"]:
result = result.astype("bool") result = result.astype("bool")
if axis is None or self.ndim == 1:
setscalar(result)
return result return result


return f return f
@@ -457,7 +446,6 @@ class ArrayMethodMixin(abc.ABC):
len(args) == 0 len(args) == 0
), "transpose for scalar does not accept additional args" ), "transpose for scalar does not accept additional args"
ret = self.to(self.device) ret = self.to(self.device)
setscalar(ret)
return ret return ret
if not args: if not args:
args = range(self.ndim)[::-1] args = range(self.ndim)[::-1]


+ 3
- 14
imperative/python/megengine/core/tensor/indexing.py View File

@@ -111,7 +111,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
if not isinstance(tuple_val, tuple): if not isinstance(tuple_val, tuple):
tuple_val = (tuple_val,) tuple_val = (tuple_val,)
ndim_indexed = 0 ndim_indexed = 0
ndim_indexed_scalar = 0
for i in tuple_val: for i in tuple_val:
if not i is Ellipsis: if not i is Ellipsis:
ndim_indexed += ( ndim_indexed += (
@@ -119,14 +118,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim")
else 1 else 1
) )
if isscalar(i):
ndim_indexed_scalar += 1
ret_scalar = False
try:
ret_scalar = ndim_indexed_scalar == inp.ndim
except ValueError:
# inp.ndim is unknown
pass
else: else:
if ndim_indexed > inp.ndim: if ndim_indexed > inp.ndim:
raise IndexError( raise IndexError(
@@ -221,7 +212,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
items.append(item) items.append(item)
if new_axes: if new_axes:
raise IndexError("newaxis is not allowed here") raise IndexError("newaxis is not allowed here")
return inp, tensors, items, use_subtensor, ret_scalar
return inp, tensors, items, use_subtensor




def try_condtake(tensor, index): def try_condtake(tensor, index):
@@ -247,14 +238,12 @@ def getitem(tensor, index):
try_result = try_condtake(tensor, index) try_result = try_condtake(tensor, index)
if len(try_result) == 2: if len(try_result) == 2:
return try_result[0] return try_result[0]
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
if use_subtensor: if use_subtensor:
op = builtin.Subtensor(items=items) op = builtin.Subtensor(items=items)
else: else:
op = builtin.IndexingMultiAxisVec(items=items) op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors) (result,) = apply(op, tensor, *tensors)
if ret_scalar:
result._setscalar()
return result return result




@@ -266,7 +255,7 @@ def setitem(tensor, index, value):
tensor = tensor.reshape(-1) tensor = tensor.reshape(-1)
if not isinstance(value, (Tensor, SymbolVar)): if not isinstance(value, (Tensor, SymbolVar)):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
if use_subtensor: if use_subtensor:
op = builtin.Subtensor(items=items) op = builtin.Subtensor(items=items)
else: else:


+ 2
- 5
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -17,6 +17,7 @@ import numpy as np


from .. import _imperative_rt from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat from .._imperative_rt import GraphOptimizeOptions, SerializationFormat
from .._imperative_rt.core2 import apply
from .._wrap import as_device from .._wrap import as_device
from ..ops.builtin import OpDef from ..ops.builtin import OpDef


@@ -126,9 +127,8 @@ class Graph(_imperative_rt.ComputingGraph):




class VarNode: class VarNode:
def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
def __init__(self, node: _imperative_rt.VarNode):
self._node = node self._node = node
self._isscalar = isscalar
if hasattr(self.graph, "_var_cache"): if hasattr(self.graph, "_var_cache"):
self.graph._var_cache[node] = self self.graph._var_cache[node] = self


@@ -530,9 +530,6 @@ def _unwrap(x):




def apply_normal_varnode(op: OpDef, *args: VarNode): def apply_normal_varnode(op: OpDef, *args: VarNode):
# for PyOp like RemoteSend/Recv
if getattr(op, "op", None):
op = op.op
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs) return _wrap(outputs)




+ 1
- 10
imperative/python/megengine/core/tensor/utils.py View File

@@ -51,10 +51,7 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype): def astype(x, dtype):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if not is_dtype_equal(x.dtype, dtype): if not is_dtype_equal(x.dtype, dtype):
isscalar = x._isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
if isscalar:
x._setscalar()
return x return x




@@ -129,13 +126,6 @@ def isscalar(x):
return np.isscalar(x) return np.isscalar(x)




def setscalar(x):
if isinstance(x, (Tensor, SymbolVar)):
x._setscalar()
else:
raise NotImplementedError("Unsupport type {}".format(type(x)))


def astensor1d(x, *reference, dtype=None, device=None): def astensor1d(x, *reference, dtype=None, device=None):
"""Convert something to 1D tensor. Support following types """Convert something to 1D tensor. Support following types


@@ -237,6 +227,7 @@ for name, mode in [
("**", "pow"), ("**", "pow"),
("max", "max"), ("max", "max"),
("additive", "add"), ("additive", "add"),
("exp", "EXP"),
]: ]:
_opr_map[(name, 2)] = builtin.Elemwise(mode=mode) _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)




+ 8
- 28
imperative/python/megengine/distributed/functional.py View File

@@ -13,7 +13,7 @@ import numpy as np
from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..core.tensor.utils import isscalar
from ..device import get_default_device, what_is_xpu from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor from ..tensor import Tensor
from . import group from . import group
@@ -72,15 +72,6 @@ def collective_comm(inp, mode, group, device):
) )
(result,) = apply(op, inp) (result,) = apply(op, inp)
# assume all workers have homogeneous shape # assume all workers have homogeneous shape
if mode in (
CollectiveComm.Mode.REDUCE_SUM,
CollectiveComm.Mode.BROADCAST,
CollectiveComm.Mode.ALL_REDUCE_SUM,
CollectiveComm.Mode.ALL_REDUCE_MAX,
CollectiveComm.Mode.ALL_REDUCE_MIN,
):
if isscalar(inp):
setscalar(result)
return result return result




@@ -190,8 +181,7 @@ def reduce_sum(
# Rank 0 # output: None # Rank 0 # output: None
# Rank 1 # output: Tensor([1]) # Rank 1 # output: Tensor([1])
""" """
op = _ReduceSum(group, device)
(out,) = apply(op, inp)
out = _ReduceSum(group, device)(inp)


if group.rank == 0: if group.rank == 0:
return out return out
@@ -258,8 +248,7 @@ def broadcast(


_bcast_tracer_state(group, inp) _bcast_tracer_state(group, inp)


op = _Broadcast(group, device)
(out,) = apply(op, inp)
out = _Broadcast(group, device)(inp)
return out return out




@@ -604,8 +593,7 @@ def gather(
inp.shape inp.shape
) )


op = _Gather(group, device)
(out,) = apply(op, inp)
out = _Gather(group, device)(inp)


if group.rank == 0: if group.rank == 0:
if axis == 0: if axis == 0:
@@ -708,8 +696,7 @@ def scatter(
+ [_ for _ in range(axis + 1, inp.ndim + 1)] + [_ for _ in range(axis + 1, inp.ndim + 1)]
) )
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape) inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
op = _Scatter(group, device)
(out,) = apply(op, inp)
out = _Scatter(group, device)(inp)
return out return out




@@ -832,7 +819,7 @@ class _RemoteRecv(Function):
self.op = op self.op = op


def forward(self, dummy): def forward(self, dummy):
return apply(self.op, dummy)
return apply(self.op, dummy)[0]


def backward(self, grad): def backward(self, grad):
get_client().bcast_val(grad is not None, self.op.key, 2) get_client().bcast_val(grad is not None, self.op.key, 2)
@@ -871,7 +858,7 @@ def remote_send(inp: Tensor, dest_rank: int):
op.addr, op.port = get_mm_server_addr() op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank op.rank_to = dest_rank
op.backend = _backend() op.backend = _backend()
(out,) = apply(_RemoteSend(op), inp)
out = _RemoteSend(op)(inp)


_save_output_for_autodiff(inp, out) _save_output_for_autodiff(inp, out)


@@ -912,11 +899,6 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
inp = Tensor(0, device=device) inp = Tensor(0, device=device)
_bcast_tracer_state(group, inp) _bcast_tracer_state(group, inp)


_isscalar = False
if len(shape) == 0:
shape = (1,)
_isscalar = True

op = RemoteRecv() op = RemoteRecv()
op.key = group.key op.key = group.key
op.cn = device op.cn = device
@@ -926,7 +908,5 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
op.rank_from = src_rank op.rank_from = src_rank
op.backend = _backend() op.backend = _backend()


(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
setscalar(ret)
ret = _RemoteRecv(op)(inp)
return ret return ret

+ 0
- 3
imperative/python/megengine/distributed/helper.py View File

@@ -67,9 +67,6 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
op.offsets = offsets op.offsets = offsets
op.shapes = [s or (1,) for s in shapes] op.shapes = [s or (1,) for s in shapes]
outputs = apply(op, inp) outputs = apply(op, inp)
for s, x in zip(shapes, outputs):
if not s:
x._setscalar()
return outputs return outputs






+ 0
- 25
imperative/python/megengine/experimental/autograd.py View File

@@ -1,25 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..core._imperative_rt.core2 import (
set_allow_higher_order_directive as _set_allow_higher_order_directive,
)

__all__ = [
"enable_higher_order_directive",
"disable_higher_order_directive",
]


def enable_higher_order_directive():
_set_allow_higher_order_directive(True)


def disable_higher_order_directive():
_set_allow_higher_order_directive(False)

+ 0
- 3
imperative/python/megengine/functional/inplace.py View File

@@ -12,8 +12,5 @@ from ..core.ops.builtin import InplaceAdd




def _inplace_add_(dest, delta, alpha, beta): def _inplace_add_(dest, delta, alpha, beta):
isscalar = dest._isscalar()
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
if isscalar:
dest._setscalar()
return dest return dest

+ 1
- 3
imperative/python/megengine/functional/math.py View File

@@ -19,7 +19,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import amp from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace from ..jit import exclude_from_trace
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default from ..utils.deprecation import deprecated_kwargs_default
@@ -1149,7 +1149,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
inp1.ndim <= 1 and inp2.ndim <= 1 inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar" ), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2) (result,) = apply(op, inp1, inp2)
setscalar(result)
return result return result




@@ -1200,5 +1199,4 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
for i in range(len(inps)): for i in range(len(inps)):
inps[i]._reset(oups[i]) inps[i]._reset(oups[i])


out._setscalar()
return out return out

+ 0
- 1
imperative/python/megengine/functional/nn.py View File

@@ -35,7 +35,6 @@ from ..core.tensor.utils import (
cast_tensors, cast_tensors,
convert_single_value, convert_single_value,
make_shape_tuple, make_shape_tuple,
setscalar,
subgraph, subgraph,
) )
from ..device import get_default_device from ..device import get_default_device


+ 0
- 9
imperative/python/megengine/functional/tensor.py View File

@@ -972,13 +972,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
) )
axis = sorted(axis) axis = sorted(axis)
assert axis, "axis could not be empty" assert axis, "axis could not be empty"
if inp._isscalar():
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
if len(axis) == 1:
inp = copy(inp, device=None)
inp._unsetscalar()
return inp
axis = axis[1:]
op = builtin.AddAxis(axis=axis) op = builtin.AddAxis(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
return result return result
@@ -1164,8 +1157,6 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
if axis is None: if axis is None:
inp = inp.reshape(-1) # flatten inp = inp.reshape(-1) # flatten
axis = 0 axis = 0
if inp._isscalar():
inp._unsetscalar()
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device) shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
# assume inp.ndim is not changed during trace # assume inp.ndim is not changed during trace
max_axis = len(shape) - 1 max_axis = len(shape) - 1


+ 1
- 13
imperative/python/megengine/jit/__init__.py View File

@@ -6,19 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
set_cpp_apply_const_with_tracing,
set_cpp_apply_with_tracing,
)
from .dtr_config import DTRConfig from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import (
apply_const_with_tracing,
apply_with_tracing,
exclude_from_trace,
trace,
)

set_cpp_apply_with_tracing(apply_with_tracing)
set_cpp_apply_const_with_tracing(apply_const_with_tracing)
from .tracing import TraceError, exclude_from_trace, trace

+ 161
- 962
imperative/python/megengine/jit/tracing.py
File diff suppressed because it is too large
View File


+ 35
- 6
imperative/python/megengine/module/module.py View File

@@ -111,6 +111,7 @@ class Module(metaclass=ABCMeta):


# used for profiler and automatic naming # used for profiler and automatic naming
self._name = None self._name = None
self._short_name = None


@abstractmethod @abstractmethod
def forward(self, inputs): def forward(self, inputs):
@@ -137,7 +138,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook) return HookHandler(self._forward_hooks, hook)


def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
AutoNaming.push_scope(self.name if self.name is not None else self._name)
AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
for hook in self._forward_pre_hooks.values(): for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs) modified_inputs = hook(self, inputs)
if modified_inputs is not None: if modified_inputs is not None:
@@ -641,15 +642,43 @@ class Module(metaclass=ABCMeta):
else: else:
if modules is not None and name in modules: if modules is not None and name in modules:
modules.remove(name) modules.remove(name)
for k, v in _expand_structure(name, value):
if not v._name:
v._name = k
elif v._name != k:

def append_name(prefix, name):
if prefix is None or prefix == "":
return name
return prefix + "." + name

def set_name(parent, prefix, name, obj):
if isinstance(obj, Tensor):
assert obj.name is not None
if obj.name != "":
name = obj.name
full_name = append_name(prefix, name)
if obj._short_name and obj._short_name != name:
logger.warning( logger.warning(
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
type(v), type(self), k, v._name
obj._short_name, type(parent), name, obj._short_name
) )
) )
return
if isinstance(obj, Tensor):
obj._prefix = prefix
obj._name = full_name
obj._short_name = name
obj._set_name(obj._name)
return obj._name
elif isinstance(obj, Module):
obj._name = full_name
obj._short_name = name
for k, v in obj._flatten(recursive=False, with_key=True):
set_name(obj, full_name, k, v)
return obj._name
else:
assert False

for k, v in _expand_structure(name, value):
prefix = self._name if self._name else self.name
set_name(self, prefix, k, v)
super().__setattr__(name, value) super().__setattr__(name, value)


def __delattr__(self, name: str): def __delattr__(self, name: str):


+ 5
- 0
imperative/python/megengine/random/rng.py View File

@@ -14,6 +14,7 @@ from numpy.random import MT19937


from .. import Tensor from .. import Tensor
from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import ( from ..core._imperative_rt.ops import (
@@ -650,6 +651,10 @@ class RNG:


def __del__(self): def __del__(self):
if self._handle != 0: if self._handle != 0:
# RNG op might execute after handle released due to async dispatch, so
# we need sync before delete a handle to avoid memory leak or
# use-after-free
_sync()
_delete_rng_handle(self._handle) _delete_rng_handle(self._handle)






+ 20
- 7
imperative/python/megengine/tensor.py View File

@@ -12,7 +12,7 @@ import numpy as np


from .core._imperative_rt import CompNode from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._imperative_rt.core2 import apply, set_py_tensor_type
from .core._trace_option import use_symbolic_shape from .core._trace_option import use_symbolic_shape
from .core._wrap import as_device from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape from .core.ops.builtin import Copy, GetVarShape
@@ -20,7 +20,6 @@ from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device from .device import _valid_device, get_default_device
from .logger import get_logger from .logger import get_logger
from .utils.deprecation import deprecated from .utils.deprecation import deprecated
from .utils.naming import AutoNaming


logger = get_logger(__name__) logger = get_logger(__name__)


@@ -40,6 +39,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
grad = None grad = None
dmap_callback = None dmap_callback = None
_qparams = None _qparams = None
_custom_name = ""
_name = None
_short_name = None
_prefix = None


def __new__( def __new__(
cls, cls,
@@ -81,9 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
device: str = None, device: str = None,
is_const: bool = False, is_const: bool = False,
no_cache: bool = False, no_cache: bool = False,
name: str = None,
name: str = "",
): ):
pass
if name is None:
name = ""
self._custom_name = name
self._name = name
self._short_name = name
self._set_name(self._name)
self._prefix = None


@property @property
def shape(self) -> Union[tuple, "Tensor"]: def shape(self) -> Union[tuple, "Tensor"]:
@@ -151,12 +160,13 @@ class Tensor(_Tensor, ArrayMethodMixin):


@property @property
def name(self): def name(self):
return self.c_name
return self._custom_name


@name.setter @name.setter
def name(self, name): def name(self, name):
self.c_name = name
AutoNaming.record_var_name(self._mixin_handle, name)
self._custom_name = name
self._name = self._prefix + "." + name if self._prefix else name
self._set_name(self._name)


@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value): def set_value(self, value):
@@ -224,6 +234,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
self._qparams = qparams self._qparams = qparams




set_py_tensor_type(Tensor)


tensor = Tensor tensor = Tensor






+ 0
- 3
imperative/python/megengine/traced_module/__init__.py View File

@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat from . import compat
from ._passes import optimize from ._passes import optimize
from .pytree import register_supported_type from .pytree import register_supported_type
@@ -14,14 +13,12 @@ from .tm_config import disable_default_checker, enable_expr_checker
from .traced_module import ( from .traced_module import (
TracedModule, TracedModule,
_register_all_builtin_module, _register_all_builtin_module,
cpp_apply_module_trace,
register_as_builtin, register_as_builtin,
trace_module, trace_module,
wrap, wrap,
) )


_register_all_builtin_module() _register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)


__all__ = [ __all__ = [
"register_as_builtin", "register_as_builtin",


+ 0
- 3
imperative/python/megengine/traced_module/checker.py View File

@@ -13,7 +13,6 @@ import numpy as np
from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import ROIAlign, ROIPooling from ..core._imperative_rt.ops import ROIAlign, ROIPooling
from ..core.ops.builtin import Copy from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..tensor import Tensor from ..tensor import Tensor
from .tm_config import _exclude_from_trace from .tm_config import _exclude_from_trace


@@ -70,8 +69,6 @@ class TracedModuleChecker:
self.current_node2values()[node] = apply( self.current_node2values()[node] = apply(
Copy(comp_node=value.device), value Copy(comp_node=value.device), value
)[0] )[0]
if isscalar(value):
setscalar(self.current_node2values()[node])


def check_apply_special_cases(self, opdef, num_outputs): def check_apply_special_cases(self, opdef, num_outputs):
indexs = list(range(num_outputs)) indexs = list(range(num_outputs))


+ 11
- 2
imperative/python/megengine/traced_module/expr.py View File

@@ -20,6 +20,7 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import ( from ..core._imperative_rt.core2 import (
apply, apply,
is_tracing_module, is_tracing_module,
set_module_trace_hook,
set_module_tracing, set_module_tracing,
unset_module_tracing, unset_module_tracing,
) )
@@ -605,8 +606,7 @@ class Apply(Expr):
def apply_module_trace_hook(cls, opdef, *inputs): def apply_module_trace_hook(cls, opdef, *inputs):
for i in inputs: for i in inputs:
node = NodeMixin.get(i, None) node = NodeMixin.get(i, None)
if node is None: # capture as constant
NodeMixin.wrap_safe(i, Constant.make(i))
assert node is not None


if isinstance(opdef, FakeQuant): if isinstance(opdef, FakeQuant):
inp_nodes = [NodeMixin.get(inputs[0])] inp_nodes = [NodeMixin.get(inputs[0])]
@@ -805,3 +805,12 @@ class Constant(Expr):
if isinstance(v, _ModuleState): if isinstance(v, _ModuleState):
state[k] = v.to_module() state[k] = v.to_module()
self.__dict__.update(state) self.__dict__.update(state)


def _module_trace_capture(value):
node = Constant.make(value)
NodeMixin.wrap_safe(value, node)
return node


set_module_trace_hook(Apply.apply_module_trace_hook)

+ 0
- 2
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -101,9 +101,7 @@ BUILTIN_TENSOR_WRAP_METHOD = [
"requires_grad", "requires_grad",
"_reset", "_reset",
"_isscalar", "_isscalar",
"_setscalar",
"_tuple_shape", "_tuple_shape",
"_unsetscalar",
] ]






+ 2
- 6
imperative/python/megengine/traced_module/traced_module.py View File

@@ -43,7 +43,6 @@ from ..core._imperative_rt.core2 import (
) )
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ..core.ops.builtin import Copy from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..module import Module from ..module import Module
from ..module import external as MExternal from ..module import external as MExternal
from ..module.qat import QATModule from ..module.qat import QATModule
@@ -1295,12 +1294,9 @@ def _wrapped_function(orig_func):
return orig_func(*args, **kwargs) return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor): if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1]) node = NodeMixin.get(inputs[1])
is_scalar = isscalar(inputs[1])
inputs[1] = apply( inputs[1] = apply(
Copy(comp_node=inputs[1].device), Tensor(inputs[1]) Copy(comp_node=inputs[1].device), Tensor(inputs[1])
)[0] )[0]
if is_scalar:
setscalar(inputs[1])
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor,
# which will cause they have same _NodeMixin__node in tracing. # which will cause they have same _NodeMixin__node in tracing.
NodeMixin.wrap_safe(inputs[1], node) NodeMixin.wrap_safe(inputs[1], node)
@@ -2468,8 +2464,8 @@ def trace_module(
try: try:
net_name = mod._name if mod._name else mod.__class__.__name__ net_name = mod._name if mod._name else mod.__class__.__name__
use_sym_shape = set_symbolic_shape(True) use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
set_active_module_tracer(module_tracer(_wrapped_function)) set_active_module_tracer(module_tracer(_wrapped_function))
set_module_tracing()
for cls in [Expr, Node]: for cls in [Expr, Node]:
cls._set_next_id(0) cls._set_next_id(0)
with active_module_tracer().patcher: with active_module_tracer().patcher:
@@ -2518,9 +2514,9 @@ def trace_module(
return traced_mod return traced_mod
finally: finally:
set_symbolic_shape(use_sym_shape) set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing() unset_module_tracing()
for t in mod.tensors(recursive=True): for t in mod.tensors(recursive=True):
NodeMixin.clear_node(t) NodeMixin.clear_node(t)
for t in inputs: for t in inputs:
NodeMixin.clear_node(t) NodeMixin.clear_node(t)
set_active_module_tracer(None)

+ 5
- 0
imperative/python/megengine/utils/profiler.py View File

@@ -137,6 +137,11 @@ class Profiler(ContextDecorator):
get_logger().info("process {} generating {}".format(self._pid, format)) get_logger().info("process {} generating {}".format(self._pid, format))
self._dump_callback(path, format) self._dump_callback(path, format)
get_logger().info("profiling results written to {}".format(path)) get_logger().info("profiling results written to {}".format(path))
if os.path.getsize(path) > 64 * 1024 * 1024:
get_logger().warning(
"profiling results too large, maybe you are profiling multi iters,"
"consider attach profiler in each iter separately"
)
self._dump_callback = None self._dump_callback = None
_living_profilers.remove(self) _living_profilers.remove(self)




+ 59
- 625
imperative/python/src/grad.cpp View File

@@ -9,9 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include "./grad.h" #include "./grad.h"

#include "megbrain/imperative/backward_graph_opt.h" #include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
@@ -19,465 +18,19 @@


#include "range/v3/all.hpp" #include "range/v3/all.hpp"


#include "./transformation.h"

namespace py = pybind11; namespace py = pybind11;
namespace views = ranges::views; namespace views = ranges::views;


namespace mgb::imperative::python { namespace mgb::imperative::python {


using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace { namespace {

struct GradSlotWeakPtr {
std::weak_ptr<GradFn> grad_fn;
size_t idx;
};

std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
ApplyContext& ctx, const apply_result_t& outputs) {
// hash
using OptimizedBackwardGraphCache = OpMethResultCache<
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
thread_local OptimizedBackwardGraphCache cache;
decltype(cache)::key_t cache_key{ctx.op};
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
input_descs.resize(ctx.nargs);
input_requires_grad.resize(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
input_descs[i].layout.dtype = ctx.args[i]->dtype();
input_descs[i].comp_node = ctx.args[i]->comp_node();
input_requires_grad[i] = python::input_requires_grad(ctx, i);
}

auto iter = cache.find(cache_key);
if (iter != cache.end()) {
return iter->second;
}

// slow path
SmallVector<bool> output_has_grad(outputs.size(), true);
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = OpDef::make_backward_graph(
*ctx.op, input_descs, input_requires_grad, output_has_grad);
if (!bg.graph.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
cache.emplace(cache_key, ret);
return ret;
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
} }


struct BackwardGraphWithClosure {
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
SmallVector<std::shared_ptr<Tensor>> closure;
size_t output_mask_offset;
size_t grad_mask_offset;

BackwardGraphWithClosure(
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
ApplyContext& ctx, const apply_result_t& outputs)
: backward_graph(backward_graph_),
output_mask_offset(ctx.nargs),
grad_mask_offset(ctx.nargs + outputs.size()) {
// save_for_backward[0:nargs]:
// whether input is kept for backward
//
// save_for_backward[nargs:nargs+outputs.size()]:
// whether output is kept for backward
//
// save_for_backward[-outputs.size():]:
// whether gradient of output can propagate to any input
//
// Example:
// perform c = a * b, with a.requires_grad == True and
// b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
auto& save_for_backward = backward_graph->save_for_backward;
mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
size_t count = std::count_if(
save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
if (!backward_graph->precomp.empty()) {
auto&& irng = ranges::span(ctx.args, ctx.nargs);
auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); });
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
closure.reserve(precomp.size() + count);
std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
} else {
closure.reserve(count);
}
for (size_t i = 0; i < ctx.nargs; ++i) {
if (save_for_backward[i]) {
closure.push_back(ctx.args[i]->shared_from_this());
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (save_for_backward[ctx.nargs + i]) {
closure.push_back(outputs[i]);
}
}
}

template <typename T, typename R>
void operator()(BackwardContext&, T&& grads, R&& receiver) {
Tensor* args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& t : closure) {
args[nargs++] = t.get();
}
bool null_grad = false;
for (size_t i = 0; i < grads.size(); ++i) {
if (backward_graph->save_for_backward[grad_mask_offset + i]) {
if (grads[i]) {
if (null_grad) {
PyErr_SetString(PyExc_NotImplementedError, "report to devs");
throw py::error_already_set();
}
args[nargs++] = grads[i];
} else {
null_grad = true;
}
}
}
if (null_grad)
return;

auto igrads = apply(backward_graph->backward, args, nargs);
auto&& it = igrads.begin();
for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
if (p) {
receiver(i, std::move(*it));
++it;
}
}
}

bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }

bool output_requires_grad(size_t i) {
return backward_graph->save_for_backward[grad_mask_offset + i];
}

bool output_captured(size_t i) {
return backward_graph->save_for_backward[output_mask_offset + i];
}
};

struct PythonBackward {
py::object pyfunc;
size_t input_size;

PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {}

template <typename T, typename R>
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
auto args = py::tuple(grads.size());
for (size_t i = 0; i < grads.size(); ++i) {
auto&& g = grads[i];
args[i] = g ? ctx.wrap_tensor(g) : py::none();
}
auto input_grads = py::reinterpret_steal<py::object>(
PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
if (!input_grads)
throw py::error_already_set();
if (input_grads.is_none())
return;
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
if (input_size != 1) {
throw py::value_error(
"custom grad rule returned wrong number of grads");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(input_grads.ptr());
}
receiver(0, tw->m_tensor);
return;
}
if (py::len(input_grads) != input_size) {
throw py::value_error("custom grad rule returned wrong number of grads");
}
for (auto [i, g] : views::enumerate(input_grads)) {
if (g.is_none())
continue;
auto* tw = TensorWrapper::try_cast(g.ptr());
if (!tw) {
throw py::type_error("custom grad rule returned non-tensor");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(g.ptr());
}
receiver(i, tw->m_tensor);
}
}

static constexpr bool input_has_grad(size_t) { return true; }
static constexpr bool output_requires_grad(size_t) { return true; }
static constexpr bool output_captured(size_t) { return true; }
};

} // namespace

struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
using Base = intrusive_list::Node<GradProducerRecord>;

GradProducerRecord() = default;
GradProducerRecord(GradProducerRecord::head_t& head)
: Base(intrusive_list::after_t{}, head) {}
// GradProducerRecord(GradProducerRecord&&) = default;
// GradProducerRecord& operator=(GradProducerRecord&) = default;
// GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
std::shared_ptr<Tensor> grad;
py::object callback;
GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
GradProducerRecord producer_record;

GradSlotProducerPtr() = default;
GradSlotProducerPtr(GradInfo& info)
: GradSlotPtr(info), producer_record(info->producer_head) {}
};

struct GradFn : std::enable_shared_from_this<GradFn> {
static MemPool<GradFn> pool;

std::weak_ptr<GradKey> key;
// slots for receiving and accumulating grads
// same length as outputs (of forward op)
SmallVector<GradSlot> slots;
// where to send and accumulate grads
// same length as inputs (of forward op)
SmallVector<GradSlotProducerPtr> dsts;
// encapsules actual function to compute gradient
std::variant<
std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward>
backward;
// a flag used during backward
bool in_ref_keeper = false;

static void deleter(GradFn* ptr) { pool.free(ptr); }

static std::shared_ptr<GradFn> make() {
return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
}

void clear() {
key.reset();
slots.clear();
dsts.clear();
backward.emplace<std::monostate>();
}
};

GradSlotPtr::operator bool() const {
return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
std::shared_ptr<GradFn> grad_fn;

GradFn* get() {
if (!grad_fn) {
grad_fn = std::make_shared<GradFn>();
}
return grad_fn.get();
}

friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
template <typename T, typename... Args>
auto& emplace(Args&&... args) {
return get()->backward.emplace<T>(std::forward<Args>(args)...);
}

void reset() { grad_fn = nullptr; }
};

apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
// copy inputs first, or trace will make InputNodes for each usage
ApplyContext ctx_dup = ctx;
SmallVector<std::shared_ptr<Tensor>> inputs_copy;
SmallVector<Tensor*> inputs_copy_weak;
for (size_t i = 0; i < ctx.nargs; ++i) {
Tensor* input = ctx.args[i];
inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
inputs_copy_weak.push_back(inputs_copy.back().get());
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
if (input->m_flags & Flags::GRAD) {
inputs_copy.back()->m_flags |= Flags::GRAD;
}
}
ctx_dup.args = inputs_copy_weak.data();

auto outputs = apply(ctx_dup);

auto backward_graph = make_backward_graph(ctx_dup, outputs);
if (!backward_graph) {
return outputs;
}
ret_grad_fn.emplace<BackwardGraphWithClosure>(
std::move(backward_graph), ctx_dup, outputs);

return outputs;
}

apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
auto* op = ctx.op->try_cast_final<GenericPyOp>();
py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto grad_rule = py::getattr(op->obj, "_grad_rule");
auto pyret = py::reinterpret_steal<py::object>(
PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
if (!pyret)
throw py::error_already_set();
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
return {tw->m_tensor};
}
apply_result_t ret;
ret.reserve(py::len(outputs));
for (auto&& i : outputs) {
auto* tw = TensorWrapper::try_cast(i.ptr());
mgb_assert(tw);
ret.push_back(tw->m_tensor);
}
return ret;
}

} // namespace

apply_result_t apply_grad(ApplyContext& ctx) {
std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
for (size_t i = 0; i < ctx.nargs; ++i) {
auto* tensor = ctx.args[i];
if (!tensor->m_grad_info_dict.empty()) {
size_t grad_cnt = 0;
for (auto&& grad_info : tensor->m_grad_info_dict) {
auto input_grad_key = grad_info.grad_fn->key.lock();
if (input_grad_key && input_grad_key->active &&
!input_grad_key->is_blocked()) {
grad_keys.insert(input_grad_key);
grad_cnt++;
}
}
if (!grad_cnt) {
tensor->m_flags &= ~Flags::GRAD;
}
} else {
tensor->m_flags &= ~Flags::GRAD;
}
}

ctx.flags &= ~Flags::GRAD;

if (grad_keys.empty()) {
return apply(ctx);
} else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
PyErr_SetString(
PyExc_NotImplementedError,
"second order directive not enabled, please call "
"'megengine.experimental.enable_higher_order_directive'");
throw pyext17::py_err_set();
}

GradFnHelper grad_fn_holder;
auto outputs = [&]() {
auto _ = scoped_disable(Flags::GRAD);
if (ctx.op->same_type<GenericPyOp>()) {
return python_grad_rule(ctx, grad_fn_holder);
}
auto&& registry = grad_rule_registry();
auto&& it = registry.find(ctx.op->dyn_typeinfo());
if (it != registry.end()) {
auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
if (auto ret = it->second(ctx, maker)) {
maker.finalize();
return *ret;
}
grad_fn_holder.reset();
}
return backward_graph_grad_rule(ctx, grad_fn_holder);
}();

if (!grad_fn_holder.grad_fn) {
return outputs;
}

for (auto&& grad_key : grad_keys) {
auto grad_fn = std::make_shared<GradFn>();
grad_fn->backward = grad_fn_holder.grad_fn->backward;
grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->dsts.reserve(ctx.nargs);

std::visit(
[&](auto& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_assert(0);
} else {
for (size_t i = 0; i < ctx.nargs; ++i) {
if (backward.input_has_grad(i) &&
input_requires_grad(ctx, i) &&
ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
auto& input_grad_info =
ctx.args[i]->m_grad_info_dict.at(
grad_key.get());
grad_fn->dsts.emplace_back(input_grad_info);
// register as grad producer
grad_fn->dsts.back().producer_record.insert_after(
input_grad_info->producer_head);
} else {
grad_fn->dsts.emplace_back();
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
if (backward.output_captured(i)) {
// avoid reference cycle [Tensor <-> GradFn]
static std::shared_ptr<OpDef> op =
std::make_shared<FastpathCopy>();
outputs[i] = python::apply(op, outputs[i])[0];
}
// populate grad info of output tensor
auto& grad_info =
outputs[i]->m_grad_info_dict[grad_key.get()];
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Flags::GRAD;
}
}
}
},
grad_fn->backward);

// record forward history
grad_key->tape.emplace_back(grad_fn);
}

return outputs;
}

PyObject* GradKeyWrapper::get_priority() {
return py::cast(m_key->priority).release().ptr();
}

void GradKeyWrapper::set_priority(pybind11::handle priority) {
m_key->priority = py::cast<int>(priority);
GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {
grad_key_map[m_key] = this;
} }


void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
@@ -488,157 +41,59 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
if (!tw) { if (!tw) {
throw py::type_error("argument 1 must be Tensor"); throw py::type_error("argument 1 must be Tensor");
} }
auto* tensor = tw->m_tensor.get();
py::object callback; py::object callback;
if (args[1] != Py_None) { if (args[1] != Py_None) {
callback = py::reinterpret_borrow<py::object>(args[1]); callback = py::reinterpret_borrow<py::object>(args[1]);
} }
m_key->attach(tensor, std::move(callback));
}

//! GradKey is weakly refered by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
if (!active) {
throw py::value_error("grad key finalized");
}

if (tensor->m_grad_info_dict.count(this)) {
if (tensor->m_grad_info_dict.at(this)->callback) {
throw py::value_error("callback already set on this tensor");
GenericFunction generic_callback =
[=](Span<ValueRef> inputs) -> std::vector<ValueRef> {
mgb_assert(inputs.size() == 1);
if (callback) {
callback(TensorWrapper::make(py_tensor_type, inputs[0]));
} }
} else {
auto& grad_info = tensor->m_grad_info_dict[this];
grad_info.idx = 0;
auto& grad_fn = grad_info.grad_fn;
grad_fn = std::make_shared<GradFn>();
grad_fn->key = shared_from_this();
grad_fn->slots.resize(1);
grad_info.insert_after(free_vars_head);
tensor->m_flags |= Flags::GRAD;
}
tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}

template <typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
if (!grad) {
grad = std::forward<T>(delta);
return;
}
static std::shared_ptr<OpDef> op =
std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
grad = apply(op, grad, std::forward<T>(delta))[0];
return {};
};
tw->m_tensor->reset(imperative::apply(
AttachGrad(m_key), tw->m_tensor->data(),
FunctionValue::make(generic_callback))[0]);
} }


void GradKey::backward(
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
if (!active) {
throw py::value_error("finalized");
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
std::vector<ValueRef> args;
mgb_assert(tensors.size() == grads.size());
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
} }
if (tensors.size() != grads.size()) {
throw py::value_error("tensor and grad size mismatch");
for (auto&& grad : grads) {
args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data());
} }

// this GradKey is marked inactive here
active = false;
struct CleanupGuard {
GradKey* owner;
size_t priority_backup;
CleanupGuard(GradKey* this_) : owner(this_) {
priority_backup = sm_min_priority;
sm_min_priority = owner->priority + 1;
}
~CleanupGuard() {
owner->cleanup();
sm_min_priority = priority_backup;
}
} _cleanup_guard(this);

if (tape.empty())
return;

BackwardContext bctx;
if (!grads.empty()) {
bctx.pytype = Py_TYPE(grads[0]->self().ptr());
}

for (size_t i = 0; i < tensors.size(); ++i) {
if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
continue;
}
auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
grad_info->grad = grads[i]->m_tensor;
}

std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());

// back-propagation in reverse order
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto&& grad_fn = tape[k].lock();
if (!grad_fn)
continue;

auto grad_receiver = [&](size_t i, auto&& g) {
auto& dst = grad_fn->dsts[i];
if (dst) {
accum_grad(dst->grad, std::forward<decltype(g)>(g));
}
};
std::visit(
[&](auto&& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_assert(0);
} else {
auto&& grads = views::transform(
grad_fn->slots,
[](auto&& slot) { return slot.grad.get(); });
backward(
bctx, std::forward<decltype(grads)>(grads),
grad_receiver);
}
},
grad_fn->backward);

for (auto&& dst : grad_fn->dsts) {
if (!dst.grad_fn)
continue;
if (!dst.grad_fn->in_ref_keeper) {
// after grad_fn is cleared, refcnt of subsequent grad_fn
// could drop to 0
dst.grad_fn->in_ref_keeper = true;
ref_keeper.push_back(dst.grad_fn);
}
if (!dst.producer_record.next && dst->callback && dst->grad) {
// I'm the last grad producer, invoke callback
dst->callback(bctx.wrap_tensor(dst->grad));
}
}
grad_fn->clear();
} // finish tape loop
imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
} }


void GradKey::cleanup() {
active = false;
tape.clear();
for (intrusive_list::Iterator it(free_vars_head); it;) {
it->grad_fn.reset();
(it++)->unlink();
pybind11::function GradKeyWrapper::get_backward_closure(
GradKeyWrapper* self, py::list tensors) {
std::vector<ValueRef> args;
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
} }
}

void GradKeyWrapper::backward(
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
m_key->backward(std::move(tensors), std::move(grads));
auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0]
.as<FunctionValue>();
auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
std::vector<ValueRef> args;
for (auto* tw : tensors) {
args.push_back(tw->m_tensor->data());
}
(*closure)(args);
};
return pybind11::cpp_function(py_function);
} }


PyObject* GradKeyWrapper::get_name() { PyObject* GradKeyWrapper::get_name() {
return py::cast(m_key->name).release().ptr();
return py::cast(m_key->name()).release().ptr();
} }


void GradKeyWrapper::set_name(py::handle name) { void GradKeyWrapper::set_name(py::handle name) {
m_key->name = py::cast<std::string>(name);
m_key->name(py::cast<std::string>(name));
} }


PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
@@ -651,60 +106,39 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
PyErr_SetString(PyExc_TypeError, "expect Tensor"); PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr; return nullptr;
} }
if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0]
.cast<BoolValue>()) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
Py_RETURN_FALSE; Py_RETURN_FALSE;
} }


int GradKey::sm_min_priority = std::numeric_limits<int>::min();
GradKey::~GradKey() {
cleanup();
void GradKeyWrapper::enter() {
m_transformation = std::make_shared<GradTransformation>(m_key);
TransformationManager::get_instance().register_at<TransformationManager::Grad>(
m_transformation);
} }


std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
static std::unordered_map<Typeinfo*, GradRuleFn> registry;
return registry;
void GradKeyWrapper::exit() {
TransformationManager::get_instance().unregister<TransformationManager::Grad>(
m_transformation);
m_transformation.reset();
} }


void GradInfoCollection::_shrink() {
auto pred = [](GradInfo& info) {
return !(info.grad_fn) || info.grad_fn->key.expired();
};
auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
m_storage.erase(iter, m_storage.end());
void GradKeyWrapper::suppress() {
m_transformation->suppress();
} }


bool GradInfoCollection::contains(GradKey* key) {
_shrink();
for (auto&& grad_info : m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return true;
}
}
return false;
void GradKeyWrapper::resume() {
m_transformation->resume();
} }


GradInfo& GradInfoCollection::operator[](GradKey* key) {
_shrink();
for (auto&& grad_info : m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return grad_info;
}
}
m_storage.emplace_back();
return m_storage.back();
GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
return grad_key_map.at(key);
} }


GradInfo& GradInfoCollection::at(GradKey* key) {
_shrink();
for (auto&& grad_info : m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return grad_info;
}
}
mgb_assert(false);
GradKeyWrapper::~GradKeyWrapper() {
grad_key_map.erase(m_key);
} }


} // namespace mgb::imperative::python } // namespace mgb::imperative::python

+ 15
- 141
imperative/python/src/grad.h View File

@@ -12,166 +12,40 @@
#pragma once #pragma once


#include "./tensor.h" #include "./tensor.h"

#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/utils/small_vector.h"


#include <megbrain/utils/small_vector.h>
#include <memory> #include <memory>
#include <optional> #include <optional>


namespace mgb::imperative::python { namespace mgb::imperative::python {


apply_result_t apply_grad(ApplyContext& ctx);

struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
std::string name;
bool active = true;
GradInfo::head_t free_vars_head;
std::vector<std::weak_ptr<GradFn>> tape;
int priority = 0;

~GradKey();

void attach(Tensor* tensor, pybind11::object callback);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
void cleanup();
bool is_blocked() const { return priority < sm_min_priority; }
inline static bool allow_higher_order_directive = false;

private:
static int sm_min_priority;
};

struct GradKeyWrapper {
struct GradKeyWrapper : NonCopyableObj {
using wrap_t = pyext17::wrap<GradKeyWrapper>; using wrap_t = pyext17::wrap<GradKeyWrapper>;
static constexpr auto tp_name = pybind11::detail::_("GradKey"); static constexpr auto tp_name = pybind11::detail::_("GradKey");


std::shared_ptr<GradKey> m_key; std::shared_ptr<GradKey> m_key;
std::shared_ptr<GradTransformation> m_transformation;


inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}
GradKeyWrapper();


PyObject* get_name(); PyObject* get_name();
void set_name(pybind11::handle name); void set_name(pybind11::handle name);
PyObject* get_priority();
void set_priority(pybind11::handle priority);
void attach(PyObject* const* args, size_t nargs); void attach(PyObject* const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list);
static pybind11::function get_backward_closure(
GradKeyWrapper* self, pybind11::list);
PyObject* is_attached_to(PyObject* const* args, size_t nargs); PyObject* is_attached_to(PyObject* const* args, size_t nargs);
void enter();
void exit();
void suppress();
void resume();
static GradKeyWrapper* get(std::shared_ptr<GradKey> key);
~GradKeyWrapper();
}; };


struct BackwardContext {
PyTypeObject* pytype = nullptr;

auto wrap_tensor(std::shared_ptr<Tensor> t) {
if (pytype) {
return TensorWrapper::make(pytype, std::move(t));
}
return TensorWrapper::make(std::move(t));
}

auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); }
};

struct CustomBackward {
using BackwardFn =
std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;
BackwardFn m_backward;
SmallVector<bool, 8> m_input_has_grad;
struct OutputAttr {
bool requires_grad = true, captured = true;
};
SmallVector<OutputAttr> m_output_attrs;

public:
template <typename T, typename R>
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
size_t nargs = grads.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = grads[i];
}
auto ret = m_backward(ctx, args, nargs);
for (size_t i = 0; i < ret.size(); ++i) {
if (auto&& t = ret[i]) {
receiver(i, std::move(t));
}
}
}

bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
bool output_captured(size_t i) { return m_output_attrs[i].captured; }

class Maker {
bool output_size_set = false, input_has_grad_initialized = false;
CustomBackward& target;
ApplyContext& ctx;

void init_input_has_grad() {
if (!input_has_grad_initialized) {
input_has_grad_initialized = true;
target.m_input_has_grad.resize(ctx.nargs, true);
}
}

public:
Maker(CustomBackward& target_, ApplyContext& ctx_)
: target(target_), ctx(ctx_) {}

template <typename F>
Maker& backward(F&& f) {
mgb_assert(!target.m_backward);
target.m_backward = std::forward<F>(f);
return *this;
}
// mandatory
Maker& output_size(size_t sz) {
mgb_assert(!output_size_set);
output_size_set = true;
target.m_output_attrs.resize(sz);
return *this;
}
// optional, defaults to all true
Maker& input_has_grad(size_t i, bool v) {
init_input_has_grad();
target.m_input_has_grad.at(i) = v;
return *this;
}
// optional, defaults to all true
Maker& output_requires_grad(size_t i, bool v) {
target.m_output_attrs.at(i).requires_grad = v;
return *this;
}
// optional, defaults to all true
Maker& output_captured(size_t i, bool v) {
target.m_output_attrs.at(i).captured = v;
return *this;
}

void finalize() {
mgb_assert(output_size_set);
init_input_has_grad();
}
};

Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
};

using GradRuleFn = std::function<std::optional<apply_result_t>(
ApplyContext&, CustomBackward::Maker&)>;

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
return !ctx.args[i]->m_grad_info_dict.empty();
}

struct GradRuleFallback : std::exception {};

template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}

} // namespace mgb::imperative::python } // namespace mgb::imperative::python


namespace pybind11::detail { namespace pybind11::detail {


+ 0
- 43
imperative/python/src/grad_info.h View File

@@ -1,43 +0,0 @@
/**
* \file imperative/python/src/grad_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <memory>

#include "./intrusive_list.h"

namespace mgb::imperative::python {

struct GradKey;
struct GradFn;
struct GradSlot;

struct GradSlotPtr {
std::shared_ptr<GradFn> grad_fn;
size_t idx;

operator bool() const;
GradSlot* operator->();
};

struct GradInfo : GradSlotPtr,
intrusive_list::Node<GradInfo, intrusive_list::before_t> {
GradInfo() = default;
GradInfo(GradInfo&) = default;
GradInfo(GradInfo&&) = default;
GradInfo& operator=(GradInfo&) = default;
GradInfo& operator=(GradInfo&&) = default;
GradInfo(const GradInfo& rhs) : GradInfo(const_cast<GradInfo&>(rhs)) {}
GradInfo& operator=(const GradInfo& rhs) {
return *this = const_cast<GradInfo&>(rhs);
}
};

} // namespace mgb::imperative::python

+ 237
- 164
imperative/python/src/grad_override.cpp View File

@@ -11,261 +11,334 @@


#include "./grad.h" #include "./grad.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformations/grad.h"


namespace mgb::imperative::python { namespace mgb::imperative::python {

class CustomGradMaker {
bool output_size_set = false, input_has_grad_initialized = false;
CustomBackward& target;
size_t nr_inputs;
void init_input_has_grad() {
if (!input_has_grad_initialized) {
input_has_grad_initialized = true;
target.m_input_has_grad.resize(nr_inputs, true);
}
}

public:
CustomGradMaker(CustomBackward& target, size_t nr_inputs)
: target(target), nr_inputs(nr_inputs) {}

CustomGradMaker& backward(CustomBackward::BackwardFn f) {
mgb_assert(!target.m_backward);
target.m_backward = f;
return *this;
}
// mandatory
CustomGradMaker& output_size(size_t sz) {
mgb_assert(!output_size_set);
output_size_set = true;
target.m_output_attrs.resize(sz);
return *this;
}
// optional, defaults to all true
CustomGradMaker& input_has_grad(size_t i, bool v) {
init_input_has_grad();
target.m_input_has_grad.at(i) = v;
return *this;
}
// optional, defaults to all true
CustomGradMaker& output_requires_grad(size_t i, bool v) {
target.m_output_attrs.at(i).requires_grad = v;
return *this;
}
// optional, defaults to all true
CustomGradMaker& output_captured(size_t i, bool v) {
target.m_output_attrs.at(i).captured = v;
return *this;
}
void finalize() {
mgb_assert(output_size_set);
init_input_has_grad();
}
};

namespace { namespace {


std::shared_ptr<Tensor> get_shape(Tensor* x) {
ValueRef get_shape(ValueRef x) {
static auto op = GetVarShape::make(); static auto op = GetVarShape::make();
return python::apply(op, x)[0];
return imperative::apply(*op, x)[0];
} }


std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
ValueRef reduce_to(ValueRef x, ValueRef s) {
static auto op = Reduce::make(); static auto op = Reduce::make();
return python::apply(op, x, s)[0];
return imperative::apply(*op, x, s)[0];
} }


std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
ValueRef reshape_to(ValueRef x, ValueRef s) {
static auto op = Reshape::make(); static auto op = Reshape::make();
return python::apply(op, x, s)[0];
return imperative::apply(*op, x, s)[0];
} }


std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
ValueRef broadcast_to(ValueRef x, ValueRef s) {
static auto op = Broadcast::make(); static auto op = Broadcast::make();
return python::apply(op, x, s)[0];
return imperative::apply(*op, x, s)[0];
} }


std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
HostTensorND scalar{cn, {{1}, dtype}};
std::memset(scalar.raw_ptr(), 0, dtype.size());
interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
auto&& t = std::make_shared<Tensor>(handle);
auto res = broadcast_to(t.get(), shape);
ValueRef make_empty_tensor(
CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
HostTensorStorage storage(*device);
storage.ensure_size(dtype->size());
std::memset(storage.ptr(), 0, dtype->size());
auto t = imperative::apply(
CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
HostStorage::make(storage))[0];
auto res = broadcast_to(t, shape);
return res; return res;
} }


std::optional<apply_result_t> elemwise_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto& op = ctx.op->cast_final_safe<Elemwise>();
if (op.mode == Elemwise::Mode::ADD) {
mgb_assert(ctx.nargs == 2);
std::array<std::shared_ptr<Tensor>, 2> input_shapes;
std::optional<std::vector<ValueRef>> elemwise_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto& elemwise = op.cast_final_safe<Elemwise>();
if (elemwise.mode != Elemwise::Mode::ADD) {
return {};
}
mgb_assert(inputs.size() == 2);
std::array<ValueRef, 2> input_shapes;
for (size_t i = 0; i < 2; ++i) {
if (inputs_require_grad[i]) {
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) { for (size_t i = 0; i < 2; ++i) {
if (input_requires_grad(ctx, i)) {
input_shapes[i] = get_shape(ctx.args[i]);
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i]);
} }
} }
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i].get());
}
}
return ret;
});
return apply(ctx);
}
return {};
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
} }


std::optional<apply_result_t> reshape_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
mgb_assert(ctx.nargs == 2);
std::array<std::shared_ptr<Tensor>, 2> input_shapes;
std::optional<std::vector<ValueRef>> reshape_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 2);
std::array<ValueRef, 2> input_shapes;
for (size_t i = 0; i < 2; ++i) { for (size_t i = 0; i < 2; ++i) {
if (input_requires_grad(ctx, i)) {
input_shapes[i] = get_shape(ctx.args[i]);
if (inputs_require_grad[i]) {
input_shapes[i] = get_shape(inputs[i]);
} }
} }
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false); maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(2);
if (!grad) { if (!grad) {
return ret; return ret;
} }
for (size_t i = 0; i < 2; ++i) { for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) { if (shapes[i]) {
ret[i] = reshape_to(grad, shapes[i].get());
ret[i] = reshape_to(grad, shapes[i]);
} }
} }
return ret; return ret;
}); });
return apply(ctx);
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
} }


std::optional<apply_result_t> subtensor_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<Subtensor>();
auto&& grad_op = SetSubtensor::make(op.items);
SmallVector<std::shared_ptr<Tensor>> inputs;
if (input_requires_grad(ctx, 0)) {
inputs.push_back(get_shape(ctx.args[0]));
for (size_t i = 1; i < ctx.nargs; ++i) {
inputs.push_back(ctx.args[i]->copy());
std::optional<std::vector<ValueRef>> subtensor_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& subtensor = op.cast_final_safe<Subtensor>();
auto&& grad_op = SetSubtensor::make(subtensor.items);
SmallVector<ValueRef> inputs2;
if (inputs_require_grad[0]) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
} }
} }
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false); maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && inputs[0]) { if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(
grad->comp_node(), inputs[0].get(), grad->dtype());
args_[0] = zeros.get();
SmallVector<ValueRef> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = grad; args_[1] = grad;
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
args_[i + 1] = inputs[i].get();
args_[i + 1] = inputs[i];
} }
ret[0] = python::apply(grad_op_, args_)[0];
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
} }
return ret; return ret;
}); });
return apply(ctx);
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
} }


std::optional<apply_result_t> indexingMultiAxisVec_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
SmallVector<std::shared_ptr<Tensor>> inputs;
if (input_requires_grad(ctx, 0)) {
inputs.push_back(get_shape(ctx.args[0]));
for (size_t i = 1; i < ctx.nargs; ++i) {
inputs.push_back(ctx.args[i]->copy());
std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
SmallVector<ValueRef> inputs2;
if (inputs_require_grad[0]) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
} }
} }
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false); maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && inputs[0]) { if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(
grad->comp_node(), inputs[0].get(), grad->dtype());
args_[0] = zeros.get();
SmallVector<ValueRef> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = grad; args_[1] = grad;
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
args_[i + 1] = inputs[i].get();
args_[i + 1] = inputs[i];
} }
ret[0] = python::apply(grad_op_, args_)[0];
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
} }
return ret; return ret;
}); });
return apply(ctx);
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
} }


std::optional<apply_result_t> reduce_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto& op = ctx.op->cast_final_safe<Reduce>();
if (op.mode == Reduce::Mode::SUM) {
if (ctx.nargs != 1) {
return {};
}
std::array<std::shared_ptr<Tensor>, 1> input_shapes;
if (input_requires_grad(ctx, 0)) {
input_shapes[0] = get_shape(ctx.args[0]);
}
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0].get());
}
return ret;
});
return apply(ctx);
std::optional<std::vector<ValueRef>> reduce_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto& reduce = op.cast_final_safe<Reduce>();
if (reduce.mode != Reduce::Mode::SUM) {
return {};
}
if (inputs.size() != 1) {
return {};
}
std::array<ValueRef, 1> input_shapes;
if (inputs_require_grad[0]) {
input_shapes[0] = get_shape(inputs[0]);
} }
return {};
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0]);
}
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
} }


std::optional<apply_result_t> addAxis_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<AddAxis>();
mgb_assert(ctx.nargs == 1);
bool flag = input_requires_grad(ctx, 0);
auto&& grad_op = RemoveAxis::make(op.axis);
std::optional<std::vector<ValueRef>> addAxis_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& addAxis = op.cast_final_safe<AddAxis>();
mgb_assert(inputs.size() == 1);
bool flag = inputs_require_grad[0];
auto&& grad_op = RemoveAxis::make(addAxis.axis);
std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>()); std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false); maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && flag_) { if (grad && flag_) {
ret[0] = python::apply(grad_op_, grad)[0];
ret[0] = imperative::apply(*grad_op_, grad)[0];
} }
return ret; return ret;
}); });
return apply(ctx);
maker.finalize();
return imperative::apply(op, inputs);
} }


std::optional<apply_result_t> removeAxis_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
mgb_assert(ctx.nargs == 1);
bool flag = input_requires_grad(ctx, 0);
auto&& grad_op = AddAxis::make(op.axis);
std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
mgb_assert(inputs.size() == 1);
bool flag = inputs_require_grad[0];
auto&& grad_op = AddAxis::make(removeAxis.axis);
std::sort(grad_op->axis.begin(), grad_op->axis.end()); std::sort(grad_op->axis.begin(), grad_op->axis.end());
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false); maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && flag_) { if (grad && flag_) {
ret[0] = python::apply(grad_op_, grad)[0];
ret[0] = imperative::apply(*grad_op_, grad)[0];
} }
return ret; return ret;
}); });
return apply(ctx);
maker.finalize();
return imperative::apply(op, inputs);
} }


std::optional<apply_result_t> fastpathcopy_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
mgb_assert(ctx.nargs == 1);
std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 1);
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false); maker.output_size(1).output_captured(0, false);
maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad) { if (grad) {
ret[0] = grad->shared_from_this();
ret[0] = grad;
} }
return ret; return ret;
}); });
return apply(ctx);
maker.finalize();
return imperative::apply(op, inputs);
} }


struct Init { struct Init {
Init() { Init() {
auto& reg = grad_rule_registry();
reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
CustomBackward::register_grad_rule(
IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
CustomBackward::register_grad_rule(
RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
} }
} _; } _;




+ 0
- 245
imperative/python/src/intrusive_list.h View File

@@ -1,245 +0,0 @@
/**
* \file imperative/python/src/intrusive_list.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/utils/metahelper.h"

namespace mgb::imperative::python::intrusive_list {

// copy policy
struct after_t {};
struct before_t {};
struct disable_t {};

template <typename T>
struct Tail;

// invariant: next->prev == this
template <typename T>
struct Head {
Tail<T>* next;

Head(Tail<T>* node = nullptr) : next(node) {}
Head(const Head<T>&) = delete;
Head<T>& operator=(const Head<T>&) = delete;
Head(Head<T>&& rhs) : next(rhs.next) {
rhs.next = nullptr;
if (next) {
next->prev = this;
}
}
Head<T>& operator=(Head<T>&& rhs) {
mgb_assert(!next);
next = rhs.next;
rhs.next = nullptr;
if (next) {
next->prev = this;
}
return *this;
}
~Head() {
if (next) {
next->prev = nullptr;
}
}
};

// invariant: prev->next == this
template <typename T>
struct Tail {
Head<T>* prev;

Tail(Head<T>* node = nullptr) : prev(node) {}
Tail(const Tail<T>&) = delete;
Tail<T>& operator=(const Tail<T>&) = delete;
Tail(Tail<T>&& rhs) : prev(rhs.prev) {
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
}
Tail<T>& operator=(Tail<T>&& rhs) {
mgb_assert(!prev);
prev = rhs.prev;
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
return *this;
}
~Tail() {
if (prev) {
prev->next = nullptr;
}
}
};

template <typename T, typename policy>
struct Node;

template <typename T>
class Iterator {
T* ptr;

void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); }
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); }

public:
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}

template <typename policy>
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}

T& operator*() { return *static_cast<T*>(ptr); }
T* operator->() { return static_cast<T*>(ptr); }

operator bool() { return ptr; }
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; }

Iterator& operator++() {
inc();
return *this;
}
Iterator& operator--() {
dec();
return *this;
}
Iterator operator++(int) {
auto ret = *this;
inc();
return ret;
}
Iterator operator--(int) {
auto ret = *this;
dec();
return ret;
}
};

// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
template <typename T = void, typename policy = disable_t>
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
private:
using this_t = Node<T, policy>;
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;

public:
using head_t = Head<U>;
using tail_t = Tail<U>;
using head_t::next;
using tail_t::prev;

Node() = default;
Node(const this_t&) = delete;
this_t& operator=(const this_t&) = delete;

//! constructed node is inserted after the input node
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
node.next = this;
if (next) {
next->prev = this;
}
}

//! constructed node is inserted before the input node
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
node.prev = this;
if (prev) {
prev->next = this;
}
}

Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
}

Node& operator=(this_t&& rhs) {
unlink();
prev = rhs.prev;
next = rhs.next;
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
return *this;
}

template <
typename p = policy,
typename = std::enable_if_t<
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
Node(this_t& rhs) : Node(policy{}, rhs) {}

template <
typename p = policy,
typename = std::enable_if_t<
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
this_t& operator=(this_t& rhs) {
insert(policy{}, rhs);
return *this;
}

void unlink() {
if (prev) {
prev->next = next;
}
if (next) {
next->prev = prev;
}
prev = nullptr;
next = nullptr;
}

//! this node is unlinked from its list and inserted after the input node
void insert(after_t, head_t& node) {
unlink();
prev = &node;
next = node.next;
node.next = this;
if (next) {
next->prev = this;
}
}

//! this node is unlinked from its list and inserted before the input node
void insert(before_t, tail_t& node) {
unlink();
next = &node;
prev = node.prev;
node.prev = this;
if (prev) {
prev->next = this;
}
}

void insert_before(tail_t& node) { insert(before_t{}, node); }
void insert_after(head_t& node) { insert(after_t{}, node); }

~Node() { unlink(); }
};

} // namespace mgb::imperative::python::intrusive_list

+ 0
- 42
imperative/python/src/module_trace.cpp View File

@@ -1,42 +0,0 @@
/**
* \file imperative/python/src/module_trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./module_trace.h"
#include "./helper.h" // include op pybind11 caster

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_module_trace(ApplyContext& ctx) {
apply_result_t outputs;

auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr);
if (!pyout)
throw py::error_already_set();
auto ret = py::reinterpret_steal<py::object>(pyout);

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::try_cast(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
}

} // namespace mgb::imperative::python

+ 41
- 1
imperative/python/src/module_trace.h View File

@@ -11,10 +11,50 @@


#pragma once #pragma once


#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"

#include "./tensor.h" #include "./tensor.h"


namespace mgb::imperative::python { namespace mgb::imperative::python {


apply_result_t apply_module_trace(ApplyContext& ctx);
namespace py = pybind11;

class ModuleTraceTransformation final : public Transformation {
private:
py::function m_hook_fn;
int m_enabled = 0;

std::vector<ValueRef> apply_module_trace_hook(
const OpDef& op, Span<ValueRef> input_values) {
py::list input_tws;
for (auto&& input_value : input_values) {
input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
}
py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
std::vector<ValueRef> outputs;
for (auto&& output_tw : output_tws) {
outputs.push_back(
TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data());
}
return outputs;
}

public:
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
std::vector<ValueRef> apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (op.is<ApplyOp>() && m_enabled > 0) {
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
return outputs;
} else {
return imperative::apply(op, inputs);
}
}

ValueRef unwrap(ValueRef value) override { return value; }

std::string name() const override { return "ModuleTraceTransformation"; }
};


} // namespace mgb::imperative::python } // namespace mgb::imperative::python

+ 2
- 7
imperative/python/src/ops.cpp View File

@@ -185,7 +185,8 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
} }


PyGetSetDef PyOp(OpDef)::py_getsetters[] = { PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
{const_cast<char*>("scope"), py_get_scope, py_set_scope,
const_cast<char*>("scope"), NULL},
{NULL}}; {NULL}};


Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) { Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
@@ -556,12 +557,6 @@ void init_ops(py::module m) {
m.def( m.def(
"delete_rng_handle", "delete_rng_handle",
[](size_t handle) { [](size_t handle) {
// RNG op might execute after handle released due to async dispatch, so
// we need sync before delete a handle to avoid memory leak or
// use-after-free
if (python::interpreter_for_py->check_available()) {
python::interpreter_for_py->sync();
}
mgb::CompNode::sync_all(); mgb::CompNode::sync_all();
py_task_q.wait_all_task_finish(); py_task_q.wait_all_task_finish();
rng::delete_handle(handle); rng::delete_handle(handle);


+ 493
- 541
imperative/python/src/tensor.cpp
File diff suppressed because it is too large
View File


+ 38
- 216
imperative/python/src/tensor.h View File

@@ -20,6 +20,8 @@
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"


#include "./pyext17.h" #include "./pyext17.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/utils/span.h"


namespace mgb::imperative::python { namespace mgb::imperative::python {


@@ -32,126 +34,67 @@ struct ObjectPtr : B {


} // namespace mgb::imperative::python } // namespace mgb::imperative::python


#include "./grad_info.h" // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo

namespace mgb::imperative::python { namespace mgb::imperative::python {


struct GradKey;

extern interpreter::Interpreter::Channel* interpreter_for_py; extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type;


class SharedHandle {
using Handle = interpreter::Interpreter::Handle;
static_assert(std::is_pointer_v<Handle>);
std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
inline explicit SharedHandle(Handle handle)
: holder(handle, [](auto* h) {
if (h) {
interpreter_for_py->del(h);
}
}) {}
SharedHandle(const SharedHandle&) = default;
SharedHandle& operator=(const SharedHandle&) = default;
SharedHandle(SharedHandle&&) = default;
SharedHandle& operator=(SharedHandle&&) = default;

inline Handle get() { return holder.get(); }
};

// impl in grad.cpp
class GradInfoCollection {
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
private: private:
SmallVector<GradInfo> m_storage;

protected:
void _shrink();
std::string m_name;
ValueRef m_data;


public: public:
bool contains(GradKey* key);
GradInfo& operator[](GradKey* key);
GradInfo& at(GradKey* key);
bool empty() {
_shrink();
return m_storage.empty();
}
auto begin() {
_shrink();
return m_storage.begin();
}
auto end() {
_shrink();
return m_storage.end();
}
size_t count(GradKey* key) { return contains(key) ? 1 : 0; }
};

struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using flags_t = uint64_t;

struct Flags {
static constexpr flags_t SCALAR = 1;
static constexpr flags_t GRAD = 1 << 1;
static constexpr flags_t TRACE = 1 << 2;
static constexpr flags_t MODULE_TRACE = 1 << 3;
};

flags_t m_flags = 0;

GradInfoCollection m_grad_info_dict;
TraceInfo m_trace_info;
SharedHandle m_handle;
std::string user_custom_name;
std::string automatic_name;
cg::VarNode* m_var;
pybind11::object m_module_trace_info;

using Handle = interpreter::Interpreter::Handle; using Handle = interpreter::Interpreter::Handle;


inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
inline explicit Tensor(SharedHandle handle)
: m_handle(std::move(handle)), m_var(nullptr) {}
inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {}
inline explicit Tensor(ValueRef data) : m_data{data} {}


~Tensor() = default; ~Tensor() = default;


inline std::shared_ptr<Tensor> copy() { inline std::shared_ptr<Tensor> copy() {
auto ret = std::make_shared<Tensor>(m_handle);
ret->m_flags = m_flags;
ret->m_grad_info_dict = m_grad_info_dict;
ret->m_trace_info = m_trace_info;
ret->m_var = m_var;
auto ret = std::make_shared<Tensor>(m_data.unwrap());
ret->m_name = m_name;
return ret; return ret;
} }


inline DType dtype() {
if (m_var) {
return m_var->dtype();
inline DType dtype() { return *data().dtype(); }
inline CompNode comp_node() { return *data().device(); }
inline std::optional<ValueShape> shape() {
auto shape = data().shape();
if (!shape) {
return {};
} }
return interpreter_for_py->get_dtype(m_handle.get());
return *shape;
} }
inline CompNode comp_node() {
if (m_var) {
return m_var->comp_node();
inline HostValue::ref_t numpy() { return data().numpy(); }
inline void reset(ValueRef value) {
m_data = value;
if (!m_name.empty()) {
set_name(m_name);
} }
return interpreter_for_py->get_device(m_handle.get());
} }
inline TensorShape shape() {
if (m_var) {
return m_var->shape();
inline ValueRef data() { return m_data.unwrap(); }
bool is_scalar() { return data().is_scalar(); }
inline std::string name() { return m_name; }
inline void set_name(std::string name) {
m_name = name;
if (!name.empty()) {
auto output = imperative::apply(RenameValue(name), m_data)[0];
m_data = output;
} }
return interpreter_for_py->get_shape(m_handle.get());
} }
}; };


struct TensorWrapper { struct TensorWrapper {
public:
std::shared_ptr<Tensor> m_tensor; std::shared_ptr<Tensor> m_tensor;


inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
: m_tensor(std::move(tensor)) {}
: m_tensor(std::move(tensor)) {
mgb_assert(tensor, "empty storage");
}

inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {}
TensorWrapper(PyObject* args, PyObject* kwargs); TensorWrapper(PyObject* args, PyObject* kwargs);
~TensorWrapper() = default; ~TensorWrapper() = default;


@@ -191,33 +134,17 @@ struct TensorWrapper {
void reset(PyObject*); void reset(PyObject*);
PyObject* detach(); PyObject* detach();
PyObject* isscalar(); PyObject* isscalar();
void setscalar();
void unsetscalar();
PyObject* _dev_tensor(); PyObject* _dev_tensor();
void _drop(); void _drop();
PyObject* varnode(); PyObject* varnode();
void reset_varnode();
PyObject* handle();
void set_handle(PyObject*);

PyObject* mixin_handle();
PyObject* recording(); PyObject* recording();
PyObject* copied(); PyObject* copied();

void set_mixin_handle(PyObject*);
void set_recording(PyObject*);

PyObject* compiled_info();
void set_compiled_info(PyObject*);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject*);
PyObject* module_trace_info(); PyObject* module_trace_info();
void set_module_trace_info(PyObject*); void set_module_trace_info(PyObject*);
PyObject* user_custom_name();
void set_user_custom_name(PyObject*);
PyObject* automatic_name();
void set_automatic_name(PyObject*);
void _set_name(PyObject*);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
PyObject* _detail();
void _watch();
}; };


struct PySymbolVar { struct PySymbolVar {
@@ -230,113 +157,8 @@ struct PySymbolVar {
PyObject* py_apply( PyObject* py_apply(
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);


struct ApplyContext {
static Tensor::flags_t global_disable;
static Tensor::flags_t global_enable;

Tensor::flags_t flags = 0;
std::shared_ptr<OpDef> op;
Tensor* const* args;
size_t nargs;
PyTypeObject* pytype = nullptr;
bool backward = false;

class scoped_disable : NonCopyableObj {
Tensor::flags_t saved_flags;

public:
scoped_disable(Tensor::flags_t flags)
: saved_flags(ApplyContext::global_disable) {
ApplyContext::global_disable |= flags;
}
~scoped_disable() { ApplyContext::global_disable = saved_flags; }
};
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

template <typename T>
decltype(auto) resolve_arrow(T&& p) {
if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
auto* ret = p;
return ret;
} else {
auto probe = [](auto&& p) -> decltype(p.operator->()) {};
if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
return resolve_arrow(p.operator->());
} else {
return std::forward<T>(p);
}
}
}

template <typename... Args>
constexpr bool is_all_tensor_ptr =
(... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
ApplyContext ctx;
Tensor* arg_arr[] = {resolve_arrow(args)...};
ctx.flags = (0 | ... | args->m_flags);
ctx.args = arg_arr;
ctx.nargs = sizeof...(args);
ctx.op = std::move(op);
return apply(ctx);
}

inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
}
return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t<
std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
}
return apply(op, args, nargs);
}

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value);

inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) {
SmallVector<std::shared_ptr<Tensor>> inputs;
for (size_t i = 0; i < nargs; ++i) {
inputs.push_back(args[i]->shared_from_this());
}
auto apply_functor = [](std::shared_ptr<OpDef> op,
SmallVector<std::shared_ptr<Tensor>> inputs,
size_t) { return apply(op, std::move(inputs)); };
return graph.apply(inputs, apply_functor, &make_const);
}

template <typename T>
auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t<
std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
}
return apply(graph, args, nargs);
}

void init_tensor(pybind11::module); void init_tensor(pybind11::module);


extern PyObject* cpp_apply_with_tracing;
extern PyObject* cpp_apply_backward_varnode;
extern PyObject* cpp_apply_module_trace; extern PyObject* cpp_apply_module_trace;


} // namespace mgb::imperative::python } // namespace mgb::imperative::python


+ 0
- 63
imperative/python/src/trace.cpp View File

@@ -1,63 +0,0 @@
/**
* \file imperative/python/src/trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./trace.h"
#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;

if (ctx.backward) {
// reach here when compiled=True
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = py::cast(ctx.args[i]->m_var);
}
py::object pyout = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
if (!pyout)
throw py::error_already_set();

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(pyout);
for (size_t i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode*>();
outputs.emplace_back(std::make_shared<Tensor>(pitem));
}
return outputs;
}

auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr);
if (!pyout)
throw py::error_already_set();

// assumption: python function always returns PyList
auto tup = py::reinterpret_steal<py::list>(pyout);
for (size_t i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::try_cast(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
}

} // namespace mgb::imperative::python

+ 0
- 28
imperative/python/src/trace.h View File

@@ -1,28 +0,0 @@
/**
* \file imperative/python/src/trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <stdexcept>
#include "./tensor.h"

namespace mgb::imperative::python {

class TraceReadError : public std::exception {
public:
explicit TraceReadError(const char* m) : message{m} {}
const char* what() const noexcept override { return message.c_str(); }

private:
std::string message = "";
};

apply_result_t apply_trace(ApplyContext& ctx);

} // namespace mgb::imperative::python

+ 0
- 49
imperative/python/src/trace_info.h View File

@@ -1,49 +0,0 @@
/**
* \file imperative/python/src/trace_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "Python.h"
#include "inttypes.h"

namespace mgb::imperative::python {

struct TraceInfo {
int64_t mixin_handle = -1;
bool recording = false;

// refer to CompiledTensorProxy in tracing.py, works from second trace step
PyObject* compiled_info = nullptr;
// refer to TensorInfo in tracing.py, only works in first trace step
PyObject* trace_mixin_info = nullptr;

TraceInfo() = default;

TraceInfo& operator=(const TraceInfo& that) {
mixin_handle = that.mixin_handle;
recording = that.recording;

trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info);
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);

return *this;
}

~TraceInfo() {
Py_XDECREF(trace_mixin_info);
Py_XDECREF(compiled_info);
}

private:
TraceInfo(const TraceInfo& that) = default;
};

} // namespace mgb::imperative::python

+ 0
- 14
imperative/python/test/conftest.py View File

@@ -16,10 +16,6 @@ import megengine.module
from megengine import Parameter from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_device_count from megengine.device import get_device_count
from megengine.experimental.autograd import (
disable_higher_order_directive,
enable_higher_order_directive,
)
from megengine.jit import trace as _trace from megengine.jit import trace as _trace
from megengine.module import Linear, Module from megengine.module import Linear, Module


@@ -45,13 +41,3 @@ def skip_distributed(request):
platform.system() platform.system()
) )
) )


@pytest.fixture(autouse=True)
def resolve_require_higher_order_directive(request):
marker = request.node.get_closest_marker("require_higher_order_directive")
if marker:
enable_higher_order_directive()
yield
if marker:
disable_higher_order_directive()

+ 1
- 1
imperative/python/test/integration/test_trace_dump.py View File

@@ -146,5 +146,5 @@ def test_dump_bn_train_mode():


data = mge.tensor(np.random.random((10, 10, 10, 10))) data = mge.tensor(np.random.random((10, 10, 10, 10)))
bn_train(data) bn_train(data)
with pytest.raises(AssertionError):
with pytest.raises(RuntimeError):
bn_train.dump("test.mge") bn_train.dump("test.mge")

+ 43
- 8
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -17,7 +17,7 @@ import megengine.distributed as dist
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.optimizer as optim import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.autodiff import Function, GradManager
from megengine.jit import trace from megengine.jit import trace




@@ -214,7 +214,7 @@ def test_remote_grad(trace_mode):
x = dist.functional.remote_recv(rank - 1) x = dist.functional.remote_recv(rank - 1)
y = m(x) y = m(x)
if rank != size - 1: if rank != size - 1:
dist.functional.remote_send(y, dest_rank=rank + 1)
x = dist.functional.remote_send(y, dest_rank=rank + 1)
gm.backward() gm.backward()
else: else:
y = y.mean() y = y.mean()
@@ -224,7 +224,7 @@ def test_remote_grad(trace_mode):
if trace_mode is not None: if trace_mode is not None:
train_func = trace(symbolic=trace_mode)(train_func) train_func = trace(symbolic=trace_mode)(train_func)


for i in range(3):
for i in range(1):
train_func(x) train_func(x)


worker() worker()
@@ -340,7 +340,6 @@ def test_broadcast_grad(trace_mode):
worker() worker()




@pytest.mark.require_higher_order_directive()
def test_2nd_grad_with_manager(): def test_2nd_grad_with_manager():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np) x = mge.tensor(x_np)
@@ -359,7 +358,6 @@ def test_2nd_grad_with_manager():
) )




@pytest.mark.require_higher_order_directive()
def test_grad_manager_group(): def test_grad_manager_group():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np) x = mge.tensor(x_np)
@@ -376,7 +374,6 @@ def test_grad_manager_group():
x.grad = None x.grad = None




@pytest.mark.require_higher_order_directive()
def test_grad_manager_group_visibility(): def test_grad_manager_group_visibility():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np) x = mge.tensor(x_np)
@@ -392,7 +389,6 @@ def test_grad_manager_group_visibility():
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)




@pytest.mark.require_higher_order_directive()
def test_grad_manager_visibility_by_order(): def test_grad_manager_visibility_by_order():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np) x = mge.tensor(x_np)
@@ -410,7 +406,6 @@ def test_grad_manager_visibility_by_order():
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)




@pytest.mark.require_higher_order_directive()
@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) @pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1])
def test_emulate_forward_mode_with_reverse_mode(target): def test_emulate_forward_mode_with_reverse_mode(target):
def jvp(inp, expr): def jvp(inp, expr):
@@ -434,3 +429,43 @@ def test_emulate_forward_mode_with_reverse_mode(target):


np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5)
np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3)


def test_2nd_grad_with_custom_gradient():
class MySin(Function):
def forward(self, x):
self.inp = x
x = mge.Tensor(x.numpy())
y = F.sin(x)
return y

def backward(self, dy):
dx = F.cos(self.inp) * dy
return dx

class MyCos(Function):
def forward(self, x):
self.inp = x
x = mge.Tensor(x.numpy())
y = F.cos(x)
return y

def backward(self, dy):
dx = -MySin()(self.inp) * dy
return dx

x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)

gm = GradManager().attach([x])
gm2 = GradManager().attach([x])

with gm:
with gm2:
y = MyCos()(x)
gm2.backward(y)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
gm.backward(x.grad)
np.testing.assert_almost_equal(
x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
)

+ 162
- 161
imperative/python/test/unit/core/test_autodiff.py View File

@@ -7,8 +7,6 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc import gc
import platform
import weakref


import numpy as np import numpy as np
import pytest import pytest
@@ -60,24 +58,20 @@ def test_dist_grad():
def worker(): def worker():
rank = dist.get_rank() rank = dist.get_rank()
if rank == 0: if rank == 0:
grad = Grad()

x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
remote_send(x, 1)
recv_x = remote_recv(1)
y = recv_x * recv_x

grad([y], [as_tensor(np.ones_like(x_np))])
with Grad() as grad:
x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
remote_send(x, 1)
recv_x = remote_recv(1)
y = recv_x * recv_x
grad([y], [as_tensor(np.ones_like(x_np))])
np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2) np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
elif rank == 1: elif rank == 1:
grad = Grad()

recv_x = remote_recv(0)
remote_send(recv_x, 0)

grad([], [])
with Grad() as grad:
recv_x = remote_recv(0)
remote_send(recv_x, 0)
grad([], [])


worker() worker()


@@ -86,11 +80,11 @@ def test_grad():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np) x = as_tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))

y = cos(x)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = cos(x)
grad(y, as_tensor(np.ones_like(x_np)))


grad(y, as_tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np)) np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np))




@@ -98,12 +92,12 @@ def test_grad_2():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np) x = as_tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = mul(x, x)
y = mul(y, y)
grad(y, as_tensor(np.ones_like(x_np)))


y = mul(x, x)
y = mul(y, y)

grad(y, as_tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)




@@ -113,32 +107,31 @@ def test_2nd_grad():
x = as_tensor(x_np) x = as_tensor(x_np)
ones = as_tensor(np.ones_like(x_np)) ones = as_tensor(np.ones_like(x_np))


grad = Grad().wrt(x, callback=save_to(x))
grad._priority = -1
grad2 = Grad().wrt(x, callback=save_to(x))
grad2._priority = 0

y = cos(x)
with Grad("grad2") as grad2:
with Grad("grad") as grad:
grad2.wrt(x, callback=save_to(x))
grad.wrt(x, callback=save_to(x))
y = cos(x)
grad(y, ones)
z = x.grad
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)


grad(y, ones)
z = x.grad
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
x.grad = None
grad2(z, ones)


x.grad = None
grad2(z, ones)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)




def test_grad_with_tensor_wrapper(): def test_grad_with_tensor_wrapper():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = mul(x, x)
y = mul(y, y)
grad(y, mge.Tensor(np.ones_like(x_np)))


y = mul(x, x)
y = mul(y, y)

grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)




@@ -162,18 +155,21 @@ def test_release():


@check @check
def _(): def _():
g = Grad().wrt(x)
y = x * x
g(y, dy)
with Grad() as g:
g.wrt(x)
y = x * x
g(y, dy)


@check @check
def _(): def _():
with Grad().wrt(x):
with Grad() as g:
g.wrt(x)
pass pass


@check @check
def _(): def _():
with Grad().wrt(x):
with Grad() as g:
g.wrt(x)
y = x * x y = x * x




@@ -181,12 +177,12 @@ def test_grad_inplace():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = mul(x, x)
y *= y
grad(y, mge.Tensor(np.ones_like(x_np)))


y = mul(x, x)
y *= y

grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)




@@ -196,11 +192,11 @@ def test_identity():
dy_np = np.random.rand(*x.shape).astype("float32") dy_np = np.random.rand(*x.shape).astype("float32")
dy = mge.Tensor(dy_np) dy = mge.Tensor(dy_np)


grad = Grad().wrt(x, callback=save_to(x))

(y,) = apply(Identity(), x)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
(y,) = apply(Identity(), x)
grad(y, dy)


grad(y, dy)
np.testing.assert_array_equal(x.grad.numpy(), dy_np) np.testing.assert_array_equal(x.grad.numpy(), dy_np)




@@ -220,15 +216,14 @@ def test_elemwise_add():
refs["y"] = TensorWeakRef(y) refs["y"] = TensorWeakRef(y)
return x + y return x + y


grad = Grad().wrt(x, callback=save_to(x))

z = f(x, y)
del y
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
z = f(x, y)
del y
for k, r in refs.items():
assert r() is None
grad(z, dz)


for k, r in refs.items():
assert r() is None

grad(z, dz)
np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5) np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5)




@@ -245,13 +240,12 @@ def test_elemwise_relu():
refs["x"] = TensorWeakRef(x) refs["x"] = TensorWeakRef(x)
return relu(x) return relu(x)


grad = Grad().wrt(x, callback=save_to(x))
z = f(x)
assert refs["x"]() is None
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
z = f(x)
assert refs["x"]() is None
grad(z, dz)


grad(z, dz)
np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0]) np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0])




@@ -269,21 +263,21 @@ def test_reshape():
x_np = np.random.rand(2, 5).astype("float32") x_np = np.random.rand(2, 5).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}


refs = {}
def f(x):
x = x * 1
y = x.reshape(5, 2)
refs["x"] = TensorWeakRef(x)
return y


def f(x):
x = x * 1
y = x.reshape(5, 2)
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))


y = f(x)
for _, r in refs.items():
assert r() is None

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())




@@ -291,21 +285,21 @@ def test_subtensor():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
refs = {}
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}


def f(x):
x = x * 1
y = x[1:-1, :2]
refs["x"] = TensorWeakRef(x)
return y
def f(x):
x = x * 1
y = x[1:-1, :2]
refs["x"] = TensorWeakRef(x)
return y


y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy() np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
) )
@@ -315,21 +309,21 @@ def test_IndexingMultiAxisVec():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}


refs = {}
def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
refs["x"] = TensorWeakRef(x)
return y


def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))


y = f(x)
for _, r in refs.items():
assert r() is None

grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
) )
@@ -339,21 +333,21 @@ def test_AxisAddRemove():
x_np = np.random.rand(1, 5).astype("float32") x_np = np.random.rand(1, 5).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
refs = {}
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}


def f(x):
x = x * 1
y = F.squeeze(F.expand_dims(x, 2), 0)
refs["x"] = TensorWeakRef(x)
return y
def f(x):
x = x * 1
y = F.squeeze(F.expand_dims(x, 2), 0)
refs["x"] = TensorWeakRef(x)
return y


y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy() np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy()
) )
@@ -363,10 +357,11 @@ def test_Broadcast():
x_np = np.random.rand(3, 3, 1).astype("float32") x_np = np.random.rand(3, 3, 1).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
y = F.broadcast_to(x, (3, 3, 10))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = F.broadcast_to(x, (3, 3, 10))
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())




@@ -374,10 +369,11 @@ def test_interpolate_fastpath():
x_np = np.random.rand(3, 3, 32, 32).astype("float32") x_np = np.random.rand(3, 3, 32, 32).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear")
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear")
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy())




@@ -385,10 +381,11 @@ def test_Reduce_sum():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
y = x.sum(axis=0)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = x.sum(axis=0)
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())




@@ -396,10 +393,11 @@ def test_Reduce_mean():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
y = x.mean(axis=0)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = x.mean(axis=0)
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())




@@ -407,21 +405,21 @@ def test_addAxis():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}


refs = {}

def f(x):
x = x * 1
y = F.expand_dims(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y
def f(x):
x = x * 1
y = F.expand_dims(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y


y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())




@@ -429,21 +427,21 @@ def test_removeAxis():
x_np = np.random.rand(3, 3, 1, 1).astype("float32") x_np = np.random.rand(3, 3, 1, 1).astype("float32")
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}


refs = {}
def f(x):
x = x * 1
y = F.squeeze(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y


def f(x):
x = x * 1
y = F.squeeze(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))


grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())




@@ -452,11 +450,14 @@ def test_dot():
x = mge.Tensor(x) x = mge.Tensor(x)
u = F.ones((2,)) u = F.ones((2,))
v = F.ones((2,)) v = F.ones((2,))
grad = Grad().wrt(x, callback=save_to(x))


def f(x):
return F.dot(u, F.matmul(x, v))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))

def f(x):
return F.dot(u, F.matmul(x, v))

y = f(x)
grad(y, F.ones_like(y))


y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())

+ 66
- 55
imperative/python/test/unit/functional/test_functional.py View File

@@ -267,25 +267,27 @@ def _gen_roi_inp():


def test_roi_align(): def test_roi_align():
inp_feat, rois = _gen_roi_inp() inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))

output_shape = (7, 7)
out_feat = F.vision.roi_align(
inp_feat,
rois,
output_shape=output_shape,
mode="average",
spatial_scale=1.0 / 4,
sample_points=2,
aligned=True,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)
with Grad() as grad:
grad.wrt(inp_feat, callback=_save_to(inp_feat))

output_shape = (7, 7)
out_feat = F.vision.roi_align(
inp_feat,
rois,
output_shape=output_shape,
mode="average",
spatial_scale=1.0 / 4,
sample_points=2,
aligned=True,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)

grad(out_feat, tensor(F.ones_like(out_feat)))


grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)




@@ -307,20 +309,23 @@ def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
def test_correlation(): def test_correlation():
##test case 0 check the grad shape ##test case 0 check the grad shape
data1, data2 = _gen_correlation() data1, data2 = _gen_correlation()
grad = Grad().wrt(data1, callback=_save_to(data1))


out_feat = F.vision.correlation(
data1,
data2,
kernel_size=5,
max_displacement=4,
stride1=2,
stride2=2,
pad_size=2,
is_multiply=True,
)
with Grad() as grad:
grad.wrt(data1, callback=_save_to(data1))

out_feat = F.vision.correlation(
data1,
data2,
kernel_size=5,
max_displacement=4,
stride1=2,
stride2=2,
pad_size=2,
is_multiply=True,
)

grad(out_feat, tensor(F.ones_like(out_feat)))


grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)


##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194 ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194
@@ -391,32 +396,36 @@ def test_correlation():


def test_roi_pooling(): def test_roi_pooling():
inp_feat, rois = _gen_roi_inp() inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
output_shape = (7, 7)
out_feat = F.vision.roi_pooling(
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)
with Grad() as grad:
grad.wrt(inp_feat, callback=_save_to(inp_feat))
output_shape = (7, 7)
out_feat = F.vision.roi_pooling(
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)

grad(out_feat, tensor(F.ones_like(out_feat)))


grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)




def test_adaptive_avg_pool2d(): def test_adaptive_avg_pool2d():
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
oshp = (2, 2) oshp = (2, 2)
grad = Grad().wrt(inp, callback=_save_to(inp))
outp = F.adaptive_avg_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
)
with Grad() as grad:
grad.wrt(inp, callback=_save_to(inp))
outp = F.adaptive_avg_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
)

grad(outp, tensor(F.ones_like(outp)))


grad(outp, tensor(F.ones_like(outp)))
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
np.testing.assert_equal( np.testing.assert_equal(
inp.grad.numpy(), inp.grad.numpy(),
@@ -439,14 +448,16 @@ def test_adaptive_avg_pool2d():
def test_adaptive_max_pool2d(): def test_adaptive_max_pool2d():
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
oshp = (2, 2) oshp = (2, 2)
grad = Grad().wrt(inp, callback=_save_to(inp))
outp = F.adaptive_max_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
)
with Grad() as grad:
grad.wrt(inp, callback=_save_to(inp))
outp = F.adaptive_max_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
)

grad(outp, tensor(F.ones_like(outp)))


grad(outp, tensor(F.ones_like(outp)))
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
np.testing.assert_equal( np.testing.assert_equal(
inp.grad.numpy(), inp.grad.numpy(),


+ 1
- 1
imperative/python/test/unit/functional/test_tensor.py View File

@@ -351,7 +351,7 @@ def test_expand_dims_for_scalar():


for axis in [1, -2, (1, 2), (-2, -3)]: for axis in [1, -2, (1, 2), (-2, -3)]:
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)
np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)




@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])


+ 79
- 21
imperative/python/test/unit/jit/test_tracing.py View File

@@ -9,6 +9,7 @@
import inspect import inspect
import io import io
import itertools import itertools
import random
from tempfile import mkstemp from tempfile import mkstemp


import numpy as np import numpy as np
@@ -25,7 +26,7 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log from megengine.functional import exp, log
from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
from megengine.module import Module from megengine.module import Module
from megengine.random import normal, uniform from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming from megengine.utils.naming import AutoNaming
@@ -464,36 +465,92 @@ def test_trace_warp_perspective():
f(x, M) f(x, M)




def test_raise_on_trace():
step_count = 0
catch_count = 0
bad_step = 10
@pytest.mark.parametrize(
"normal_expr, mismatch_expr, reason",
[
("a + b + c", "a + b - c", "operator mismatch"),
("a + b + 1", "a + b + 2", "tensors not equals"),
("((a + b), (b + c))[0]", "a + b", "mismature end"),
("a + b + c", "c + (a + b)", "expect internal node, got external"),
("c + (a + b)", "a + b + c", "expect external node, got internal"),
("a + b + c", "a + b + c + c", "too many instructions"),
("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"),
("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
],
)
def test_trace_mismatch(normal_expr, mismatch_expr, reason):
a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2])

mismatch = False

@trace(symbolic=True)
def fn(a, b, c):
if not mismatch:
result = eval(normal_expr)
else:
result = eval(mismatch_expr)
return result

for i in range(20):
try:
d = fn(a, b, c)
except TraceError as e:
assert mismatch
assert str(e) == "trace error because {}".format(reason)
except:
pytest.fail("unexpected trace error")
else:
assert not mismatch
np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy())
mismatch = random.random() > 0.8


class CatchMe(Exception):
pass


def test_exception_in_trace():
a = tensor([1, 2, 3, 4]) a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8]) b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2]) c = tensor([9, 0, 1, 2])


@trace
def add_abc(a, b, c):
ps = a + b
result = ps + c
if step_count == bad_step:
raise CatchMe("catch me")
mismatch = False

exc = Exception()

@trace(symbolic=True)
def fn(a, b, c):
result = a + b
if not mismatch:
result += c
else:
raise exc
return result return result


for i in range(100):
for i in range(20):
try: try:
d = add_abc(a, b, c)
except CatchMe as e:
catch_count += 1
d = fn(a, b, c)
except TraceError as e:
pytest.fail("unexpected trace error")
except Exception as e:
assert mismatch
assert e is exc
else: else:
assert not mismatch
np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
step_count += 1
mismatch = random.random() > 0.8


assert catch_count == 1

def test_graph_error():
a = tensor(np.arange(8).reshape((2, 4)))
b = tensor(np.arange(8).reshape((2, 4)))

@trace(symbolic=True)
def fn(a, b):
return a + b

fn(a, b)
with pytest.raises(RuntimeError):
fn(a, b.transpose())
fn(a, b)




@pytest.mark.parametrize("trace_mode", [False, True]) @pytest.mark.parametrize("trace_mode", [False, True])
@@ -653,9 +710,10 @@ def test_trace_jit_config():


x = tensor(2) x = tensor(2)
y = func(x) y = func(x)
func._compile()
y = func(x)
# func._compile()


options = func._graph.options
options = func._trace.options
mapping = {None: 0, False: 1, True: 2} mapping = {None: 0, False: 1, True: 2}
assert options.graph_opt.jit == 0 assert options.graph_opt.jit == 0
assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle] assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]


+ 16
- 12
imperative/python/test/unit/quantization/test_fake_quant.py View File

@@ -82,9 +82,10 @@ def test_tqt():
x = mge.tensor(x, dtype="float32") x = mge.tensor(x, dtype="float32")
s = mge.tensor(s, dtype="float32") s = mge.tensor(s, dtype="float32")
g_y = mge.tensor(g_y, dtype="float32") g_y = mge.tensor(g_y, dtype="float32")
grad = Grad().wrt(x, s, callback=cb)
y = tqt_forward(-127, 127, x, s)
grad(y, g_y)
with Grad() as grad:
grad.wrt(x, s, callback=cb)
y = tqt_forward(-127, 127, x, s)
grad(y, g_y)
g_x, g_s = g g_x, g_s = g


np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5)
@@ -131,14 +132,16 @@ def test_fakequant():


# test backward # test backward
x = tensor(inp_data, dtype=np.float32) x = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qparams)
grad(y, tensor(F.ones_like(x)))
with Grad() as grad:
grad.wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qparams)
grad(y, tensor(F.ones_like(x)))


x1 = tensor(inp_data, dtype=np.float32) x1 = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x1, callback=_save_to(x1))
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax)
grad(y1, tensor(F.ones_like(x1)))
with Grad() as grad:
grad.wrt(x1, callback=_save_to(x1))
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax)
grad(y1, tensor(F.ones_like(x1)))


assert np.allclose(x.grad.numpy(), x1.grad.numpy()) assert np.allclose(x.grad.numpy(), x1.grad.numpy())
assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
@@ -237,9 +240,10 @@ def test_lsq():
grad_s = mge.tensor(grad_s, dtype="float32") grad_s = mge.tensor(grad_s, dtype="float32")


g_y = mge.tensor(g_y, dtype="float32") g_y = mge.tensor(g_y, dtype="float32")
grad = Grad().wrt(x, s, callback=cb)
y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
grad(y, g_y)
with Grad() as grad:
grad.wrt(x, s, callback=cb)
y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
grad(y, g_y)
g_x, g_s = g g_x, g_s = g


np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7)


+ 4
- 3
imperative/python/test/unit/random/test_rng.py View File

@@ -430,9 +430,10 @@ def test_ShuffleRNG():
n, m = 6, 3 n, m = 6, 3
arr = np.arange(n * m) arr = np.arange(n * m)
out0 = Tensor(arr, dtype="float32") out0 = Tensor(arr, dtype="float32")
grad = Grad().wrt(out0, callback=cb)
random.shuffle(out0)
grad(out0, F.ones_like(out0))
with Grad() as grad:
grad.wrt(out0, callback=cb)
random.shuffle(out0)
grad(out0, F.ones_like(out0))
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1") m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0") m3 = RNG(seed=222, device="xpu0")


Loading…
Cancel
Save