
refactor(dispatch): switch to new dispatch system

GitOrigin-RevId: 32dd49a23a
tags/v1.8.0
Megvii Engine Team 3 years ago
parent commit 0bdd0b1467
46 changed files with 1622 additions and 3613 deletions
  1. +11 -23 imperative/python/megengine/autodiff/grad_manager.py
  2. +98 -31 imperative/python/megengine/core/autodiff/grad.py
  3. +1 -13 imperative/python/megengine/core/tensor/array_method.py
  4. +3 -14 imperative/python/megengine/core/tensor/indexing.py
  5. +2 -5 imperative/python/megengine/core/tensor/megbrain_graph.py
  6. +1 -10 imperative/python/megengine/core/tensor/utils.py
  7. +8 -28 imperative/python/megengine/distributed/functional.py
  8. +0 -3 imperative/python/megengine/distributed/helper.py
  9. +0 -25 imperative/python/megengine/experimental/autograd.py
  10. +0 -3 imperative/python/megengine/functional/inplace.py
  11. +1 -3 imperative/python/megengine/functional/math.py
  12. +0 -1 imperative/python/megengine/functional/nn.py
  13. +0 -9 imperative/python/megengine/functional/tensor.py
  14. +1 -13 imperative/python/megengine/jit/__init__.py
  15. +161 -962 imperative/python/megengine/jit/tracing.py
  16. +35 -6 imperative/python/megengine/module/module.py
  17. +5 -0 imperative/python/megengine/random/rng.py
  18. +20 -7 imperative/python/megengine/tensor.py
  19. +0 -3 imperative/python/megengine/traced_module/__init__.py
  20. +0 -3 imperative/python/megengine/traced_module/checker.py
  21. +11 -2 imperative/python/megengine/traced_module/expr.py
  22. +0 -2 imperative/python/megengine/traced_module/module_tracer.py
  23. +2 -6 imperative/python/megengine/traced_module/traced_module.py
  24. +5 -0 imperative/python/megengine/utils/profiler.py
  25. +59 -625 imperative/python/src/grad.cpp
  26. +15 -141 imperative/python/src/grad.h
  27. +0 -43 imperative/python/src/grad_info.h
  28. +237 -164 imperative/python/src/grad_override.cpp
  29. +0 -245 imperative/python/src/intrusive_list.h
  30. +0 -42 imperative/python/src/module_trace.cpp
  31. +41 -1 imperative/python/src/module_trace.h
  32. +2 -7 imperative/python/src/ops.cpp
  33. +493 -541 imperative/python/src/tensor.cpp
  34. +38 -216 imperative/python/src/tensor.h
  35. +0 -63 imperative/python/src/trace.cpp
  36. +0 -28 imperative/python/src/trace.h
  37. +0 -49 imperative/python/src/trace_info.h
  38. +0 -14 imperative/python/test/conftest.py
  39. +1 -1 imperative/python/test/integration/test_trace_dump.py
  40. +43 -8 imperative/python/test/unit/autodiff/test_grad_manger.py
  41. +162 -161 imperative/python/test/unit/core/test_autodiff.py
  42. +66 -55 imperative/python/test/unit/functional/test_functional.py
  43. +1 -1 imperative/python/test/unit/functional/test_tensor.py
  44. +79 -21 imperative/python/test/unit/jit/test_tracing.py
  45. +16 -12 imperative/python/test/unit/quantization/test_fake_quant.py
  46. +4 -3 imperative/python/test/unit/random/test_rng.py

+11 -23 imperative/python/megengine/autodiff/grad_manager.py

@@ -28,9 +28,6 @@ class AttachSpec:
__slots__ = "tensor", "callbacks"


_global_priority = 0


class GradManager:
r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
automatic differentiation (a.k.a. back propagation).
@@ -127,7 +124,6 @@ class GradManager:
self._grad = None
self._after_backward_callback = []
self._gradients = {}
self._priority = None

def attached_tensors(self):
r"""Return attached tensor list from :meth:`attach`."""
@@ -299,31 +295,25 @@ class GradManager:
tensor.grad = grad
else:
tensor.grad += grad
if tensor._isscalar() and tensor.grad is not None:
tensor.grad._setscalar()
finally:
self.release()
backwarding_grad_manager = cache
set_option("record_computing_path", 1)
pop_scope("backward")
set_option("record_computing_path", 1)
pop_scope("backward")

def record(self):
r"""Start recording operations

After this call, you will be able to call :meth:`backward`.
"""
global _global_priority
if self._recording:
raise RuntimeError("already recording")
grad = Grad()
self._recording = True
self._grad = grad
grad.__enter__()
for spec in self._attach_specs.values():
self._do_record(spec)
if self._priority is None:
grad._priority = _global_priority
_global_priority -= 1
grad.__enter__()

def _do_record(self, spec):
tensor = spec.tensor()
@@ -331,6 +321,8 @@ class GradManager:
return

def callback(grad, callbacks=spec.callbacks):
from ..functional import ones_like

for cb in callbacks:
grad = cb(tensor, grad)
self._gradients[id(tensor)] = grad
@@ -343,14 +335,11 @@ class GradManager:

After this call, you will not be able to call :meth:`backward`.
"""
global _global_priority
if self._grad is not None:
self._grad.__exit__(None, None, None)
self._grad = None
self._recording = False
self._gradients = dict()
if self._priority is None:
_global_priority += 1

def __enter__(self):
self.record()
@@ -382,15 +371,14 @@ class GradManagerGroup:
__ror__ = merge_with

def __enter__(self):
global _global_priority
_global_priority += 1
Grad.stack.append([])
Grad.begin_group()
for gm in self._gms:
gm._priority = _global_priority
gm.record()
assert gm._grad is not None
Grad.end_group()

def __exit__(self, exc_type, exc_val, exc_tb):
global _global_priority
_global_priority -= 1
for gm in self._gms:
for gm in reversed(self._gms):
gm.release()
gm._priority = None
assert gm._grad is None
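
For reference, the user-facing recording flow that this file must keep working, as a minimal runnable sketch (grouped recording additionally routes through ``Grad.begin_group()``/``Grad.end_group()`` via ``GradManagerGroup.__enter__`` above):

    from megengine import Tensor
    from megengine.autodiff import GradManager

    gm = GradManager()
    w = Tensor([1.0, 2.0])
    gm.attach([w])
    with gm:                  # __enter__ -> record(): enters a fresh Grad
        loss = (w * w).sum()
        gm.backward(loss)     # accumulates into w.grad, then release()
    print(w.grad)             # d(loss)/dw = 2 * w -> [2., 4.]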

+98 -31 imperative/python/megengine/core/autodiff/grad.py

@@ -6,17 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from .._imperative_rt import core2

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()
@@ -36,6 +28,10 @@ class GradKey(core2.GradKey):


class Grad:
stack = []
grouping = False
key2grad = weakref.WeakValueDictionary()

def __init__(self, name=None):
global _grad_count
if name is None:
@@ -43,15 +39,9 @@ class Grad:
_grad_count += 1
self._refkeeper = []
self._impl = GradKey(name)
Grad.key2grad[self._impl] = self
_grad_manager_dict[self._name] = self

@property
def _priority(self):
return self._impl.priority

@_priority.setter
def _priority(self, priority):
self._impl.priority = priority
self._group = [weakref.ref(self)]

@property
def _name(self):
@@ -70,33 +60,80 @@ class Grad:

if not isinstance(ys, Sequence):
ys = [ys]

if not isinstance(dys, Sequence):
dys = [dys]

group = [ref() for ref in self._group]

for grad in group:
if grad is self:
continue
grad.suppress()

self._impl.backward(ys, dys)

for grad in group:
if grad is self:
continue
grad.resume()

self._refkeeper = None
return None

def __enter__(self):
ref = weakref.ref(self)
self._impl.enter()
if Grad.grouping:
group = Grad.stack[-1]
self._group = group
group.append(ref)
else:
Grad.stack.append(self._group)
return self

def __exit__(self, _1, _2, _3):
self._impl.exit()
self._refkeeper = None
del self._impl


class Function(ops.PyOpBase):
del Grad.key2grad[self._impl]
self._impl = None
self._group.remove(weakref.ref(self))
if len(self._group) == 0:
Grad.stack.remove(self._group)

@staticmethod
def begin_group():
assert not Grad.grouping
Grad.grouping = True

@staticmethod
def end_group():
group = Grad.stack[-1]
assert len(group) > 0
assert Grad.grouping
Grad.grouping = False

def suppress(self):
if self._impl is not None:
self._impl.suppress()

def resume(self):
if self._impl is not None:
self._impl.resume()


class Function:
r"""Defines a block of operations with customizable differentiation.
The computation should be defined in the ``forward`` method, and its gradient
computation in the ``backward`` method.
Each instance of ``Function`` should be used only once during a forward pass.
Examples:
.. code-block::
class Sigmoid(Function):
def forward(self, x):
y = 1 / (1 + F.exp(-x))
@@ -115,7 +152,7 @@ class Function(ops.PyOpBase):

Returns:
a tuple of Tensor or a single Tensor.
Note:
* This method should return a tuple of Tensor or a single Tensor representing the output
of the function.
@@ -128,7 +165,7 @@ class Function(ops.PyOpBase):

Args:
output_grads: gradients of outputs that are returned by :meth:`forward`.
Note:
* When some output tensors are unrelated to the loss function, the corresponding
values in ``output_grads`` will be ``None``.
@@ -148,10 +185,40 @@ class Function(ops.PyOpBase):
return self._default_rule(*args), self.backward

def __call__(self, *args):
ret = core2.apply(self, *args)
for arg in args:
if not isinstance(arg, core2.Tensor):
raise TypeError(
"op Function expect type Tensor as inputs, got {}".format(type(arg))
)

grad_key = core2.get_grad_key(args)
if grad_key is None:
return self._default_rule(*args)

grad = Grad.key2grad[grad_key]
group = [ref() for ref in grad._group]

for grad in group:
grad.suppress()
outputs, backward = self._grad_rule(*args)
for grad in reversed(group):
grad.resume()

def normalized_backward(*output_grads):
input_grads = backward(*output_grads)
if isinstance(input_grads, core2.Tensor) or input_grads is None:
input_grads = (input_grads,)
return input_grads

if self.__single_output:
(ret,) = ret
return ret
outputs = (outputs,)
for grad in reversed(group):
if grad._impl is None:
continue
outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
if self.__single_output:
(outputs,) = outputs
return outputs

def __getstate__(self):
return self.__dict__
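
Completing the truncated ``Sigmoid`` example from the docstring into a runnable sketch of the new ``__call__`` path (import path as in this file; with an attached input the call is routed through ``core2.get_grad_key``/``core2.set_grad`` as above, otherwise ``_default_rule`` runs directly):

    import megengine.functional as F
    from megengine import Tensor
    from megengine.autodiff import GradManager
    from megengine.core.autodiff.grad import Function

    class Sigmoid(Function):
        def forward(self, x):
            self.y = 1 / (1 + F.exp(-x))
            return self.y

        def backward(self, dy):
            return dy * self.y * (1 - self.y)

    gm = GradManager()
    x = Tensor([0.0])
    gm.attach([x])
    with gm:
        out = Sigmoid()(x)    # attached input -> normalized_backward is wired in
        gm.backward(out)
    print(x.grad)             # sigmoid'(0) = 0.25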


+1 -13 imperative/python/megengine/core/tensor/array_method.py

@@ -26,7 +26,6 @@ from .utils import (
convert_inputs,
isscalar,
make_shape_tuple,
setscalar,
)

_ElwMod = builtin.Elemwise.Mode
@@ -34,14 +33,7 @@ _ElwMod = builtin.Elemwise.Mode

def _elwise_apply(args, mode):
op = builtin.Elemwise(mode)
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
(result,) = apply(op, *args)
if _isscalar:
setscalar(result)
return result


@@ -203,8 +195,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:

op = builtin.RemoveAxis(axis=axis)
(result,) = apply(op, inp)
if len(axis) == inp.ndim:
setscalar(result)
return result


@@ -221,6 +211,7 @@ def _reduce(mode):

op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data)
result = _remove_axis(result, 0)
elif isinstance(axis, collections.abc.Iterable):
axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
@@ -239,8 +230,6 @@ def _reduce(mode):
if self.dtype == np.bool_:
if mode in ["min", "max"]:
result = result.astype("bool")
if axis is None or self.ndim == 1:
setscalar(result)
return result

return f
@@ -457,7 +446,6 @@ class ArrayMethodMixin(abc.ABC):
len(args) == 0
), "transpose for scalar does not accept additional args"
ret = self.to(self.device)
setscalar(ret)
return ret
if not args:
args = range(self.ndim)[::-1]


+3 -14 imperative/python/megengine/core/tensor/indexing.py

@@ -111,7 +111,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
if not isinstance(tuple_val, tuple):
tuple_val = (tuple_val,)
ndim_indexed = 0
ndim_indexed_scalar = 0
for i in tuple_val:
if not i is Ellipsis:
ndim_indexed += (
@@ -119,14 +118,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim")
else 1
)
if isscalar(i):
ndim_indexed_scalar += 1
ret_scalar = False
try:
ret_scalar = ndim_indexed_scalar == inp.ndim
except ValueError:
# inp.ndim is unknown
pass
else:
if ndim_indexed > inp.ndim:
raise IndexError(
@@ -221,7 +212,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
items.append(item)
if new_axes:
raise IndexError("newaxis is not allowed here")
return inp, tensors, items, use_subtensor, ret_scalar
return inp, tensors, items, use_subtensor


def try_condtake(tensor, index):
@@ -247,14 +238,12 @@ def getitem(tensor, index):
try_result = try_condtake(tensor, index)
if len(try_result) == 2:
return try_result[0]
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
if use_subtensor:
op = builtin.Subtensor(items=items)
else:
op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors)
if ret_scalar:
result._setscalar()
return result


@@ -266,7 +255,7 @@ def setitem(tensor, index, value):
tensor = tensor.reshape(-1)
if not isinstance(value, (Tensor, SymbolVar)):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
if use_subtensor:
op = builtin.Subtensor(items=items)
else:


+2 -5 imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -17,6 +17,7 @@ import numpy as np

from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat
from .._imperative_rt.core2 import apply
from .._wrap import as_device
from ..ops.builtin import OpDef

@@ -126,9 +127,8 @@ class Graph(_imperative_rt.ComputingGraph):


class VarNode:
def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
def __init__(self, node: _imperative_rt.VarNode):
self._node = node
self._isscalar = isscalar
if hasattr(self.graph, "_var_cache"):
self.graph._var_cache[node] = self

@@ -530,9 +530,6 @@ def _unwrap(x):


def apply_normal_varnode(op: OpDef, *args: VarNode):
# for PyOp like RemoteSend/Recv
if getattr(op, "op", None):
op = op.op
outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs)



+1 -10 imperative/python/megengine/core/tensor/utils.py

@@ -51,10 +51,7 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_dtype_equal(x.dtype, dtype):
isscalar = x._isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
if isscalar:
x._setscalar()
return x


@@ -129,13 +126,6 @@ def isscalar(x):
return np.isscalar(x)


def setscalar(x):
if isinstance(x, (Tensor, SymbolVar)):
x._setscalar()
else:
raise NotImplementedError("Unsupport type {}".format(type(x)))


def astensor1d(x, *reference, dtype=None, device=None):
"""Convert something to 1D tensor. Support following types

@@ -237,6 +227,7 @@ for name, mode in [
("**", "pow"),
("max", "max"),
("additive", "add"),
("exp", "EXP"),
]:
_opr_map[(name, 2)] = builtin.Elemwise(mode=mode)



+8 -28 imperative/python/megengine/distributed/functional.py

@@ -13,7 +13,7 @@ import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..core.tensor.utils import isscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
@@ -72,15 +72,6 @@ def collective_comm(inp, mode, group, device):
)
(result,) = apply(op, inp)
# assume all workers have homogeneous shape
if mode in (
CollectiveComm.Mode.REDUCE_SUM,
CollectiveComm.Mode.BROADCAST,
CollectiveComm.Mode.ALL_REDUCE_SUM,
CollectiveComm.Mode.ALL_REDUCE_MAX,
CollectiveComm.Mode.ALL_REDUCE_MIN,
):
if isscalar(inp):
setscalar(result)
return result


@@ -190,8 +181,7 @@ def reduce_sum(
# Rank 0 # output: None
# Rank 1 # output: Tensor([1])
"""
op = _ReduceSum(group, device)
(out,) = apply(op, inp)
out = _ReduceSum(group, device)(inp)

if group.rank == 0:
return out
@@ -258,8 +248,7 @@ def broadcast(

_bcast_tracer_state(group, inp)

op = _Broadcast(group, device)
(out,) = apply(op, inp)
out = _Broadcast(group, device)(inp)
return out


@@ -604,8 +593,7 @@ def gather(
inp.shape
)

op = _Gather(group, device)
(out,) = apply(op, inp)
out = _Gather(group, device)(inp)

if group.rank == 0:
if axis == 0:
@@ -708,8 +696,7 @@ def scatter(
+ [_ for _ in range(axis + 1, inp.ndim + 1)]
)
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
op = _Scatter(group, device)
(out,) = apply(op, inp)
out = _Scatter(group, device)(inp)
return out


@@ -832,7 +819,7 @@ class _RemoteRecv(Function):
self.op = op

def forward(self, dummy):
return apply(self.op, dummy)
return apply(self.op, dummy)[0]

def backward(self, grad):
get_client().bcast_val(grad is not None, self.op.key, 2)
@@ -871,7 +858,7 @@ def remote_send(inp: Tensor, dest_rank: int):
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
op.backend = _backend()
(out,) = apply(_RemoteSend(op), inp)
out = _RemoteSend(op)(inp)

_save_output_for_autodiff(inp, out)

@@ -912,11 +899,6 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
inp = Tensor(0, device=device)
_bcast_tracer_state(group, inp)

_isscalar = False
if len(shape) == 0:
shape = (1,)
_isscalar = True

op = RemoteRecv()
op.key = group.key
op.cn = device
@@ -926,7 +908,5 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
op.rank_from = src_rank
op.backend = _backend()

(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
setscalar(ret)
ret = _RemoteRecv(op)(inp)
return ret
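
The pattern above is uniform: the ``Function`` wrappers (``_ReduceSum``, ``_Broadcast``, ``_Gather``, ``_Scatter``, ``_RemoteSend``, ``_RemoteRecv``) are now called directly instead of going through a raw ``apply``. A hedged point-to-point sketch with the public entry points (two ranks assumed, process-group setup elided):

    import megengine.distributed as dist
    from megengine import Tensor
    from megengine.distributed.functional import remote_recv, remote_send

    # inside an already-initialized 2-rank group:
    if dist.get_rank() == 0:
        remote_send(Tensor([1.0, 2.0]), dest_rank=1)  # ships the tensor to rank 1
    else:
        y = remote_recv(src_rank=0)                   # Tensor([1., 2.]) on rank 1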

+0 -3 imperative/python/megengine/distributed/helper.py

@@ -67,9 +67,6 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
op.offsets = offsets
op.shapes = [s or (1,) for s in shapes]
outputs = apply(op, inp)
for s, x in zip(shapes, outputs):
if not s:
x._setscalar()
return outputs




+0 -25 imperative/python/megengine/experimental/autograd.py

@@ -1,25 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..core._imperative_rt.core2 import (
set_allow_higher_order_directive as _set_allow_higher_order_directive,
)

__all__ = [
"enable_higher_order_directive",
"disable_higher_order_directive",
]


def enable_higher_order_directive():
_set_allow_higher_order_directive(True)


def disable_higher_order_directive():
_set_allow_higher_order_directive(False)

+0 -3 imperative/python/megengine/functional/inplace.py

@@ -12,8 +12,5 @@ from ..core.ops.builtin import InplaceAdd


def _inplace_add_(dest, delta, alpha, beta):
isscalar = dest._isscalar()
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
if isscalar:
dest._setscalar()
return dest

+1 -3 imperative/python/megengine/functional/math.py

@@ -19,7 +19,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
@@ -1149,7 +1149,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
setscalar(result)
return result


@@ -1200,5 +1199,4 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
for i in range(len(inps)):
inps[i]._reset(oups[i])

out._setscalar()
return out

+0 -1 imperative/python/megengine/functional/nn.py

@@ -35,7 +35,6 @@ from ..core.tensor.utils import (
cast_tensors,
convert_single_value,
make_shape_tuple,
setscalar,
subgraph,
)
from ..device import get_default_device


+0 -9 imperative/python/megengine/functional/tensor.py

@@ -972,13 +972,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
)
axis = sorted(axis)
assert axis, "axis could not be empty"
if inp._isscalar():
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
if len(axis) == 1:
inp = copy(inp, device=None)
inp._unsetscalar()
return inp
axis = axis[1:]
op = builtin.AddAxis(axis=axis)
(result,) = apply(op, inp)
return result
@@ -1164,8 +1157,6 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
if axis is None:
inp = inp.reshape(-1) # flatten
axis = 0
if inp._isscalar():
inp._unsetscalar()
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
# assume inp.ndim is not changed during trace
max_axis = len(shape) - 1


+1 -13 imperative/python/megengine/jit/__init__.py

@@ -6,19 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
set_cpp_apply_const_with_tracing,
set_cpp_apply_with_tracing,
)
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import (
apply_const_with_tracing,
apply_with_tracing,
exclude_from_trace,
trace,
)

set_cpp_apply_with_tracing(apply_with_tracing)
set_cpp_apply_const_with_tracing(apply_const_with_tracing)
from .tracing import TraceError, exclude_from_trace, trace
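
What remains of the public surface, in a minimal usage sketch (``TraceError`` is the newly exported exception type; ``symbolic=False`` is an ordinary ``trace`` option rather than something this commit adds):

    from megengine import Tensor
    from megengine.jit import exclude_from_trace, trace

    @trace(symbolic=False)
    def step(x):
        y = x * 2
        with exclude_from_trace():  # this block runs eagerly on every call
            y = y + 1
        return y

    print(step(Tensor([1.0])))      # traced on first call, replayed afterwards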

+161 -962 imperative/python/megengine/jit/tracing.py (file diff suppressed because it is too large)


+35 -6 imperative/python/megengine/module/module.py

@@ -111,6 +111,7 @@ class Module(metaclass=ABCMeta):

# used for profiler and automatic naming
self._name = None
self._short_name = None

@abstractmethod
def forward(self, inputs):
@@ -137,7 +138,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook)

def __call__(self, *inputs, **kwargs):
AutoNaming.push_scope(self.name if self.name is not None else self._name)
AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs)
if modified_inputs is not None:
@@ -641,15 +642,43 @@ class Module(metaclass=ABCMeta):
else:
if modules is not None and name in modules:
modules.remove(name)
for k, v in _expand_structure(name, value):
if not v._name:
v._name = k
elif v._name != k:

def append_name(prefix, name):
if prefix is None or prefix == "":
return name
return prefix + "." + name

def set_name(parent, prefix, name, obj):
if isinstance(obj, Tensor):
assert obj.name is not None
if obj.name != "":
name = obj.name
full_name = append_name(prefix, name)
if obj._short_name and obj._short_name != name:
logger.warning(
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
type(v), type(self), k, v._name
obj._short_name, type(parent), name, obj._short_name
)
)
return
if isinstance(obj, Tensor):
obj._prefix = prefix
obj._name = full_name
obj._short_name = name
obj._set_name(obj._name)
return obj._name
elif isinstance(obj, Module):
obj._name = full_name
obj._short_name = name
for k, v in obj._flatten(recursive=False, with_key=True):
set_name(obj, full_name, k, v)
return obj._name
else:
assert False

for k, v in _expand_structure(name, value):
prefix = self._name if self._name else self.name
set_name(self, prefix, k, v)
super().__setattr__(name, value)

def __delattr__(self, name: str):


+5 -0 imperative/python/megengine/random/rng.py

@@ -14,6 +14,7 @@ from numpy.random import MT19937

from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import (
@@ -650,6 +651,10 @@ class RNG:

def __del__(self):
if self._handle != 0:
# RNG op might execute after the handle is released due to async dispatch,
# so we need to sync before deleting the handle to avoid a memory leak or
# use-after-free
_sync()
_delete_rng_handle(self._handle)
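
A hedged illustration of the race the added sync closes (seed, device string, and shape are arbitrary):

    from megengine.random import RNG

    rng = RNG(seed=42, device="xpux")
    x = rng.uniform(size=(16,))  # the kernel may still sit in the async queue
    del rng                      # __del__ now calls _sync() first, so the
                                 # queued kernel cannot run against a freed
                                 # handle, and the handle is not leaked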




+20 -7 imperative/python/megengine/tensor.py

@@ -12,7 +12,7 @@ import numpy as np

from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._imperative_rt.core2 import apply, set_py_tensor_type
from .core._trace_option import use_symbolic_shape
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
@@ -20,7 +20,6 @@ from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import AutoNaming

logger = get_logger(__name__)

@@ -40,6 +39,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
grad = None
dmap_callback = None
_qparams = None
_custom_name = ""
_name = None
_short_name = None
_prefix = None

def __new__(
cls,
@@ -81,9 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
device: str = None,
is_const: bool = False,
no_cache: bool = False,
name: str = None,
name: str = "",
):
pass
if name is None:
name = ""
self._custom_name = name
self._name = name
self._short_name = name
self._set_name(self._name)
self._prefix = None

@property
def shape(self) -> Union[tuple, "Tensor"]:
@@ -151,12 +160,13 @@ class Tensor(_Tensor, ArrayMethodMixin):

@property
def name(self):
return self.c_name
return self._custom_name

@name.setter
def name(self, name):
self.c_name = name
AutoNaming.record_var_name(self._mixin_handle, name)
self._custom_name = name
self._name = self._prefix + "." + name if self._prefix else name
self._set_name(self._name)

@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value):
@@ -224,6 +234,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
self._qparams = qparams


set_py_tensor_type(Tensor)


tensor = Tensor
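
How the new naming fields compose, in a small sketch (assigning ``_prefix`` by hand stands in for what ``Module.__setattr__`` does when registering attributes):

    from megengine import Tensor

    t = Tensor([1.0])
    t._prefix = "net.fc"            # normally filled in by Module.__setattr__
    t.name = "weight"               # setter composes prefix + custom name
    assert t._name == "net.fc.weight"
    assert t.name == "weight"       # getter returns the short custom name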




+0 -3 imperative/python/megengine/traced_module/__init__.py

@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .pytree import register_supported_type
@@ -14,14 +13,12 @@ from .tm_config import disable_default_checker, enable_expr_checker
from .traced_module import (
TracedModule,
_register_all_builtin_module,
cpp_apply_module_trace,
register_as_builtin,
trace_module,
wrap,
)

_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)

__all__ = [
"register_as_builtin",


+0 -3 imperative/python/megengine/traced_module/checker.py

@@ -13,7 +13,6 @@ import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import ROIAlign, ROIPooling
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..tensor import Tensor
from .tm_config import _exclude_from_trace

@@ -70,8 +69,6 @@ class TracedModuleChecker:
self.current_node2values()[node] = apply(
Copy(comp_node=value.device), value
)[0]
if isscalar(value):
setscalar(self.current_node2values()[node])

def check_apply_special_cases(self, opdef, num_outputs):
indexs = list(range(num_outputs))


+11 -2 imperative/python/megengine/traced_module/expr.py

@@ -20,6 +20,7 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
apply,
is_tracing_module,
set_module_trace_hook,
set_module_tracing,
unset_module_tracing,
)
@@ -605,8 +606,7 @@ class Apply(Expr):
def apply_module_trace_hook(cls, opdef, *inputs):
for i in inputs:
node = NodeMixin.get(i, None)
if node is None: # capture as constant
NodeMixin.wrap_safe(i, Constant.make(i))
assert node is not None

if isinstance(opdef, FakeQuant):
inp_nodes = [NodeMixin.get(inputs[0])]
@@ -805,3 +805,12 @@ class Constant(Expr):
if isinstance(v, _ModuleState):
state[k] = v.to_module()
self.__dict__.update(state)


def _module_trace_capture(value):
node = Constant.make(value)
NodeMixin.wrap_safe(value, node)
return node


set_module_trace_hook(Apply.apply_module_trace_hook)

+0 -2 imperative/python/megengine/traced_module/module_tracer.py

@@ -101,9 +101,7 @@ BUILTIN_TENSOR_WRAP_METHOD = [
"requires_grad",
"_reset",
"_isscalar",
"_setscalar",
"_tuple_shape",
"_unsetscalar",
]




+2 -6 imperative/python/megengine/traced_module/traced_module.py

@@ -43,7 +43,6 @@ from ..core._imperative_rt.core2 import (
)
from ..core._trace_option import set_symbolic_shape
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..module import Module
from ..module import external as MExternal
from ..module.qat import QATModule
@@ -1295,12 +1294,9 @@ def _wrapped_function(orig_func):
return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1])
is_scalar = isscalar(inputs[1])
inputs[1] = apply(
Copy(comp_node=inputs[1].device), Tensor(inputs[1])
)[0]
if is_scalar:
setscalar(inputs[1])
# copy inputs[1] to avoid tensor and Tensor(tensor) sharing the same m_tensor,
# which would cause them to have the same _NodeMixin__node in tracing.
NodeMixin.wrap_safe(inputs[1], node)
@@ -2468,8 +2464,8 @@ def trace_module(
try:
net_name = mod._name if mod._name else mod.__class__.__name__
use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
set_active_module_tracer(module_tracer(_wrapped_function))
set_module_tracing()
for cls in [Expr, Node]:
cls._set_next_id(0)
with active_module_tracer().patcher:
@@ -2518,9 +2514,9 @@ def trace_module(
return traced_mod
finally:
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing()
for t in mod.tensors(recursive=True):
NodeMixin.clear_node(t)
for t in inputs:
NodeMixin.clear_node(t)
set_active_module_tracer(None)

+5 -0 imperative/python/megengine/utils/profiler.py

@@ -137,6 +137,11 @@ class Profiler(ContextDecorator):
get_logger().info("process {} generating {}".format(self._pid, format))
self._dump_callback(path, format)
get_logger().info("profiling results written to {}".format(path))
if os.path.getsize(path) > 64 * 1024 * 1024:
get_logger().warning(
"profiling results too large, maybe you are profiling multi iters,"
"consider attach profiler in each iter separately"
)
self._dump_callback = None
_living_profilers.remove(self)



+59 -625 imperative/python/src/grad.cpp

@@ -9,9 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include "./grad.h"

#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
@@ -19,465 +18,19 @@

#include "range/v3/all.hpp"

#include "./transformation.h"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
std::weak_ptr<GradFn> grad_fn;
size_t idx;
};

std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
ApplyContext& ctx, const apply_result_t& outputs) {
// hash
using OptimizedBackwardGraphCache = OpMethResultCache<
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
thread_local OptimizedBackwardGraphCache cache;
decltype(cache)::key_t cache_key{ctx.op};
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
input_descs.resize(ctx.nargs);
input_requires_grad.resize(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
input_descs[i].layout.dtype = ctx.args[i]->dtype();
input_descs[i].comp_node = ctx.args[i]->comp_node();
input_requires_grad[i] = python::input_requires_grad(ctx, i);
}

auto iter = cache.find(cache_key);
if (iter != cache.end()) {
return iter->second;
}

// slow path
SmallVector<bool> output_has_grad(outputs.size(), true);
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = OpDef::make_backward_graph(
*ctx.op, input_descs, input_requires_grad, output_has_grad);
if (!bg.graph.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
cache.emplace(cache_key, ret);
return ret;
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
}

struct BackwardGraphWithClosure {
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
SmallVector<std::shared_ptr<Tensor>> closure;
size_t output_mask_offset;
size_t grad_mask_offset;

BackwardGraphWithClosure(
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
ApplyContext& ctx, const apply_result_t& outputs)
: backward_graph(backward_graph_),
output_mask_offset(ctx.nargs),
grad_mask_offset(ctx.nargs + outputs.size()) {
// save_for_backward[0:nargs]:
// whether input is kept for backward
//
// save_for_backward[nargs:nargs+outputs.size()]:
// whether output is kept for backward
//
// save_for_backward[-outputs.size():]:
// whether gradient of output can propagate to any input
//
// Example:
// perform c = a * b, with a.requires_grad == True and
// b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
auto& save_for_backward = backward_graph->save_for_backward;
mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
size_t count = std::count_if(
save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
if (!backward_graph->precomp.empty()) {
auto&& irng = ranges::span(ctx.args, ctx.nargs);
auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); });
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
closure.reserve(precomp.size() + count);
std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
} else {
closure.reserve(count);
}
for (size_t i = 0; i < ctx.nargs; ++i) {
if (save_for_backward[i]) {
closure.push_back(ctx.args[i]->shared_from_this());
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (save_for_backward[ctx.nargs + i]) {
closure.push_back(outputs[i]);
}
}
}

template <typename T, typename R>
void operator()(BackwardContext&, T&& grads, R&& receiver) {
Tensor* args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& t : closure) {
args[nargs++] = t.get();
}
bool null_grad = false;
for (size_t i = 0; i < grads.size(); ++i) {
if (backward_graph->save_for_backward[grad_mask_offset + i]) {
if (grads[i]) {
if (null_grad) {
PyErr_SetString(PyExc_NotImplementedError, "report to devs");
throw py::error_already_set();
}
args[nargs++] = grads[i];
} else {
null_grad = true;
}
}
}
if (null_grad)
return;

auto igrads = apply(backward_graph->backward, args, nargs);
auto&& it = igrads.begin();
for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
if (p) {
receiver(i, std::move(*it));
++it;
}
}
}

bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }

bool output_requires_grad(size_t i) {
return backward_graph->save_for_backward[grad_mask_offset + i];
}

bool output_captured(size_t i) {
return backward_graph->save_for_backward[output_mask_offset + i];
}
};

struct PythonBackward {
py::object pyfunc;
size_t input_size;

PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {}

template <typename T, typename R>
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
auto args = py::tuple(grads.size());
for (size_t i = 0; i < grads.size(); ++i) {
auto&& g = grads[i];
args[i] = g ? ctx.wrap_tensor(g) : py::none();
}
auto input_grads = py::reinterpret_steal<py::object>(
PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
if (!input_grads)
throw py::error_already_set();
if (input_grads.is_none())
return;
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
if (input_size != 1) {
throw py::value_error(
"custom grad rule returned wrong number of grads");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(input_grads.ptr());
}
receiver(0, tw->m_tensor);
return;
}
if (py::len(input_grads) != input_size) {
throw py::value_error("custom grad rule returned wrong number of grads");
}
for (auto [i, g] : views::enumerate(input_grads)) {
if (g.is_none())
continue;
auto* tw = TensorWrapper::try_cast(g.ptr());
if (!tw) {
throw py::type_error("custom grad rule returned non-tensor");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(g.ptr());
}
receiver(i, tw->m_tensor);
}
}

static constexpr bool input_has_grad(size_t) { return true; }
static constexpr bool output_requires_grad(size_t) { return true; }
static constexpr bool output_captured(size_t) { return true; }
};

} // namespace

struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
using Base = intrusive_list::Node<GradProducerRecord>;

GradProducerRecord() = default;
GradProducerRecord(GradProducerRecord::head_t& head)
: Base(intrusive_list::after_t{}, head) {}
// GradProducerRecord(GradProducerRecord&&) = default;
// GradProducerRecord& operator=(GradProducerRecord&) = default;
// GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
std::shared_ptr<Tensor> grad;
py::object callback;
GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
GradProducerRecord producer_record;

GradSlotProducerPtr() = default;
GradSlotProducerPtr(GradInfo& info)
: GradSlotPtr(info), producer_record(info->producer_head) {}
};

struct GradFn : std::enable_shared_from_this<GradFn> {
static MemPool<GradFn> pool;

std::weak_ptr<GradKey> key;
// slots for receiving and accumulating grads
// same length as outputs (of forward op)
SmallVector<GradSlot> slots;
// where to send and accumulate grads
// same length as inputs (of forward op)
SmallVector<GradSlotProducerPtr> dsts;
// encapsulates the actual function to compute the gradient
std::variant<
std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward>
backward;
// a flag used during backward
bool in_ref_keeper = false;

static void deleter(GradFn* ptr) { pool.free(ptr); }

static std::shared_ptr<GradFn> make() {
return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
}

void clear() {
key.reset();
slots.clear();
dsts.clear();
backward.emplace<std::monostate>();
}
};

GradSlotPtr::operator bool() const {
return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
std::shared_ptr<GradFn> grad_fn;

GradFn* get() {
if (!grad_fn) {
grad_fn = std::make_shared<GradFn>();
}
return grad_fn.get();
}

friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
template <typename T, typename... Args>
auto& emplace(Args&&... args) {
return get()->backward.emplace<T>(std::forward<Args>(args)...);
}

void reset() { grad_fn = nullptr; }
};

apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
// copy inputs first, or trace will make InputNodes for each usage
ApplyContext ctx_dup = ctx;
SmallVector<std::shared_ptr<Tensor>> inputs_copy;
SmallVector<Tensor*> inputs_copy_weak;
for (size_t i = 0; i < ctx.nargs; ++i) {
Tensor* input = ctx.args[i];
inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
inputs_copy_weak.push_back(inputs_copy.back().get());
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
if (input->m_flags & Flags::GRAD) {
inputs_copy.back()->m_flags |= Flags::GRAD;
}
}
ctx_dup.args = inputs_copy_weak.data();

auto outputs = apply(ctx_dup);

auto backward_graph = make_backward_graph(ctx_dup, outputs);
if (!backward_graph) {
return outputs;
}
ret_grad_fn.emplace<BackwardGraphWithClosure>(
std::move(backward_graph), ctx_dup, outputs);

return outputs;
}

apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
auto* op = ctx.op->try_cast_final<GenericPyOp>();
py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto grad_rule = py::getattr(op->obj, "_grad_rule");
auto pyret = py::reinterpret_steal<py::object>(
PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
if (!pyret)
throw py::error_already_set();
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
return {tw->m_tensor};
}
apply_result_t ret;
ret.reserve(py::len(outputs));
for (auto&& i : outputs) {
auto* tw = TensorWrapper::try_cast(i.ptr());
mgb_assert(tw);
ret.push_back(tw->m_tensor);
}
return ret;
}

} // namespace

apply_result_t apply_grad(ApplyContext& ctx) {
std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
for (size_t i = 0; i < ctx.nargs; ++i) {
auto* tensor = ctx.args[i];
if (!tensor->m_grad_info_dict.empty()) {
size_t grad_cnt = 0;
for (auto&& grad_info : tensor->m_grad_info_dict) {
auto input_grad_key = grad_info.grad_fn->key.lock();
if (input_grad_key && input_grad_key->active &&
!input_grad_key->is_blocked()) {
grad_keys.insert(input_grad_key);
grad_cnt++;
}
}
if (!grad_cnt) {
tensor->m_flags &= ~Flags::GRAD;
}
} else {
tensor->m_flags &= ~Flags::GRAD;
}
}

ctx.flags &= ~Flags::GRAD;

if (grad_keys.empty()) {
return apply(ctx);
} else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
PyErr_SetString(
PyExc_NotImplementedError,
"second order directive not enabled, please call "
"'megengine.experimental.enable_higher_order_directive'");
throw pyext17::py_err_set();
}

GradFnHelper grad_fn_holder;
auto outputs = [&]() {
auto _ = scoped_disable(Flags::GRAD);
if (ctx.op->same_type<GenericPyOp>()) {
return python_grad_rule(ctx, grad_fn_holder);
}
auto&& registry = grad_rule_registry();
auto&& it = registry.find(ctx.op->dyn_typeinfo());
if (it != registry.end()) {
auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
if (auto ret = it->second(ctx, maker)) {
maker.finalize();
return *ret;
}
grad_fn_holder.reset();
}
return backward_graph_grad_rule(ctx, grad_fn_holder);
}();

if (!grad_fn_holder.grad_fn) {
return outputs;
}

for (auto&& grad_key : grad_keys) {
auto grad_fn = std::make_shared<GradFn>();
grad_fn->backward = grad_fn_holder.grad_fn->backward;
grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->dsts.reserve(ctx.nargs);

std::visit(
[&](auto& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_assert(0);
} else {
for (size_t i = 0; i < ctx.nargs; ++i) {
if (backward.input_has_grad(i) &&
input_requires_grad(ctx, i) &&
ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
auto& input_grad_info =
ctx.args[i]->m_grad_info_dict.at(
grad_key.get());
grad_fn->dsts.emplace_back(input_grad_info);
// register as grad producer
grad_fn->dsts.back().producer_record.insert_after(
input_grad_info->producer_head);
} else {
grad_fn->dsts.emplace_back();
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
if (backward.output_captured(i)) {
// avoid reference cycle [Tensor <-> GradFn]
static std::shared_ptr<OpDef> op =
std::make_shared<FastpathCopy>();
outputs[i] = python::apply(op, outputs[i])[0];
}
// populate grad info of output tensor
auto& grad_info =
outputs[i]->m_grad_info_dict[grad_key.get()];
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Flags::GRAD;
}
}
}
},
grad_fn->backward);

// record forward history
grad_key->tape.emplace_back(grad_fn);
}

return outputs;
}

PyObject* GradKeyWrapper::get_priority() {
return py::cast(m_key->priority).release().ptr();
}

void GradKeyWrapper::set_priority(pybind11::handle priority) {
m_key->priority = py::cast<int>(priority);
GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {
grad_key_map[m_key] = this;
}

void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
@@ -488,157 +41,59 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
if (!tw) {
throw py::type_error("argument 1 must be Tensor");
}
auto* tensor = tw->m_tensor.get();
py::object callback;
if (args[1] != Py_None) {
callback = py::reinterpret_borrow<py::object>(args[1]);
}
m_key->attach(tensor, std::move(callback));
}

//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
if (!active) {
throw py::value_error("grad key finalized");
}

if (tensor->m_grad_info_dict.count(this)) {
if (tensor->m_grad_info_dict.at(this)->callback) {
throw py::value_error("callback already set on this tensor");
GenericFunction generic_callback =
[=](Span<ValueRef> inputs) -> std::vector<ValueRef> {
mgb_assert(inputs.size() == 1);
if (callback) {
callback(TensorWrapper::make(py_tensor_type, inputs[0]));
}
} else {
auto& grad_info = tensor->m_grad_info_dict[this];
grad_info.idx = 0;
auto& grad_fn = grad_info.grad_fn;
grad_fn = std::make_shared<GradFn>();
grad_fn->key = shared_from_this();
grad_fn->slots.resize(1);
grad_info.insert_after(free_vars_head);
tensor->m_flags |= Flags::GRAD;
}
tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}

template <typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
if (!grad) {
grad = std::forward<T>(delta);
return;
}
static std::shared_ptr<OpDef> op =
std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
grad = apply(op, grad, std::forward<T>(delta))[0];
return {};
};
tw->m_tensor->reset(imperative::apply(
AttachGrad(m_key), tw->m_tensor->data(),
FunctionValue::make(generic_callback))[0]);
}

void GradKey::backward(
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
if (!active) {
throw py::value_error("finalized");
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
std::vector<ValueRef> args;
mgb_assert(tensors.size() == grads.size());
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
}
if (tensors.size() != grads.size()) {
throw py::value_error("tensor and grad size mismatch");
for (auto&& grad : grads) {
args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data());
}

// this GradKey is marked inactive here
active = false;
struct CleanupGuard {
GradKey* owner;
size_t priority_backup;
CleanupGuard(GradKey* this_) : owner(this_) {
priority_backup = sm_min_priority;
sm_min_priority = owner->priority + 1;
}
~CleanupGuard() {
owner->cleanup();
sm_min_priority = priority_backup;
}
} _cleanup_guard(this);

if (tape.empty())
return;

BackwardContext bctx;
if (!grads.empty()) {
bctx.pytype = Py_TYPE(grads[0]->self().ptr());
}

for (size_t i = 0; i < tensors.size(); ++i) {
if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
continue;
}
auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
grad_info->grad = grads[i]->m_tensor;
}

std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());

// back-propagation in reverse order
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto&& grad_fn = tape[k].lock();
if (!grad_fn)
continue;

auto grad_receiver = [&](size_t i, auto&& g) {
auto& dst = grad_fn->dsts[i];
if (dst) {
accum_grad(dst->grad, std::forward<decltype(g)>(g));
}
};
std::visit(
[&](auto&& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_assert(0);
} else {
auto&& grads = views::transform(
grad_fn->slots,
[](auto&& slot) { return slot.grad.get(); });
backward(
bctx, std::forward<decltype(grads)>(grads),
grad_receiver);
}
},
grad_fn->backward);

for (auto&& dst : grad_fn->dsts) {
if (!dst.grad_fn)
continue;
if (!dst.grad_fn->in_ref_keeper) {
// after grad_fn is cleared, refcnt of subsequent grad_fn
// could drop to 0
dst.grad_fn->in_ref_keeper = true;
ref_keeper.push_back(dst.grad_fn);
}
if (!dst.producer_record.next && dst->callback && dst->grad) {
// I'm the last grad producer, invoke callback
dst->callback(bctx.wrap_tensor(dst->grad));
}
}
grad_fn->clear();
} // finish tape loop
imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
}

void GradKey::cleanup() {
active = false;
tape.clear();
for (intrusive_list::Iterator it(free_vars_head); it;) {
it->grad_fn.reset();
(it++)->unlink();
pybind11::function GradKeyWrapper::get_backward_closure(
GradKeyWrapper* self, py::list tensors) {
std::vector<ValueRef> args;
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
}
}

void GradKeyWrapper::backward(
std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
m_key->backward(std::move(tensors), std::move(grads));
auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0]
.as<FunctionValue>();
auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
std::vector<ValueRef> args;
for (auto* tw : tensors) {
args.push_back(tw->m_tensor->data());
}
(*closure)(args);
};
return pybind11::cpp_function(py_function);
}

PyObject* GradKeyWrapper::get_name() {
return py::cast(m_key->name).release().ptr();
return py::cast(m_key->name()).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
m_key->name = py::cast<std::string>(name);
m_key->name(py::cast<std::string>(name));
}

PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
@@ -651,60 +106,39 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0]
.cast<BoolValue>()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

int GradKey::sm_min_priority = std::numeric_limits<int>::min();
GradKey::~GradKey() {
cleanup();
void GradKeyWrapper::enter() {
m_transformation = std::make_shared<GradTransformation>(m_key);
TransformationManager::get_instance().register_at<TransformationManager::Grad>(
m_transformation);
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
static std::unordered_map<Typeinfo*, GradRuleFn> registry;
return registry;
void GradKeyWrapper::exit() {
TransformationManager::get_instance().unregister<TransformationManager::Grad>(
m_transformation);
m_transformation.reset();
}

void GradInfoCollection::_shrink() {
auto pred = [](GradInfo& info) {
return !(info.grad_fn) || info.grad_fn->key.expired();
};
auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
m_storage.erase(iter, m_storage.end());
void GradKeyWrapper::suppress() {
m_transformation->suppress();
}

bool GradInfoCollection::contains(GradKey* key) {
_shrink();
for (auto&& grad_info : m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return true;
}
}
return false;
void GradKeyWrapper::resume() {
m_transformation->resume();
}

GradInfo& GradInfoCollection::operator[](GradKey* key) {
_shrink();
for (auto&& grad_info : m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return grad_info;
}
}
m_storage.emplace_back();
return m_storage.back();
GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
return grad_key_map.at(key);
}

GradInfo& GradInfoCollection::at(GradKey* key) {
_shrink();
for (auto&& grad_info : m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return grad_info;
}
}
mgb_assert(false);
GradKeyWrapper::~GradKeyWrapper() {
grad_key_map.erase(m_key);
}

} // namespace mgb::imperative::python

+15 -141 imperative/python/src/grad.h

@@ -12,166 +12,40 @@
#pragma once

#include "./tensor.h"

#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/utils/small_vector.h"

#include <megbrain/utils/small_vector.h>
#include <memory>
#include <optional>

namespace mgb::imperative::python {

apply_result_t apply_grad(ApplyContext& ctx);

struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
std::string name;
bool active = true;
GradInfo::head_t free_vars_head;
std::vector<std::weak_ptr<GradFn>> tape;
int priority = 0;

~GradKey();

void attach(Tensor* tensor, pybind11::object callback);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
void cleanup();
bool is_blocked() const { return priority < sm_min_priority; }
inline static bool allow_higher_order_directive = false;

private:
static int sm_min_priority;
};

struct GradKeyWrapper {
struct GradKeyWrapper : NonCopyableObj {
using wrap_t = pyext17::wrap<GradKeyWrapper>;
static constexpr auto tp_name = pybind11::detail::_("GradKey");

std::shared_ptr<GradKey> m_key;
std::shared_ptr<GradTransformation> m_transformation;

inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}
GradKeyWrapper();

PyObject* get_name();
void set_name(pybind11::handle name);
PyObject* get_priority();
void set_priority(pybind11::handle priority);
void attach(PyObject* const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list);
static pybind11::function get_backward_closure(
GradKeyWrapper* self, pybind11::list);
PyObject* is_attached_to(PyObject* const* args, size_t nargs);
void enter();
void exit();
void suppress();
void resume();
static GradKeyWrapper* get(std::shared_ptr<GradKey> key);
~GradKeyWrapper();
};

struct BackwardContext {
PyTypeObject* pytype = nullptr;

auto wrap_tensor(std::shared_ptr<Tensor> t) {
if (pytype) {
return TensorWrapper::make(pytype, std::move(t));
}
return TensorWrapper::make(std::move(t));
}

auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); }
};

struct CustomBackward {
using BackwardFn =
std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;
BackwardFn m_backward;
SmallVector<bool, 8> m_input_has_grad;
struct OutputAttr {
bool requires_grad = true, captured = true;
};
SmallVector<OutputAttr> m_output_attrs;

public:
template <typename T, typename R>
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
size_t nargs = grads.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = grads[i];
}
auto ret = m_backward(ctx, args, nargs);
for (size_t i = 0; i < ret.size(); ++i) {
if (auto&& t = ret[i]) {
receiver(i, std::move(t));
}
}
}

bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
bool output_captured(size_t i) { return m_output_attrs[i].captured; }

class Maker {
bool output_size_set = false, input_has_grad_initialized = false;
CustomBackward& target;
ApplyContext& ctx;

void init_input_has_grad() {
if (!input_has_grad_initialized) {
input_has_grad_initialized = true;
target.m_input_has_grad.resize(ctx.nargs, true);
}
}

public:
Maker(CustomBackward& target_, ApplyContext& ctx_)
: target(target_), ctx(ctx_) {}

template <typename F>
Maker& backward(F&& f) {
mgb_assert(!target.m_backward);
target.m_backward = std::forward<F>(f);
return *this;
}
// mandatory
Maker& output_size(size_t sz) {
mgb_assert(!output_size_set);
output_size_set = true;
target.m_output_attrs.resize(sz);
return *this;
}
// optional, defaults to all true
Maker& input_has_grad(size_t i, bool v) {
init_input_has_grad();
target.m_input_has_grad.at(i) = v;
return *this;
}
// optional, defaults to all true
Maker& output_requires_grad(size_t i, bool v) {
target.m_output_attrs.at(i).requires_grad = v;
return *this;
}
// optional, defaults to all true
Maker& output_captured(size_t i, bool v) {
target.m_output_attrs.at(i).captured = v;
return *this;
}

void finalize() {
mgb_assert(output_size_set);
init_input_has_grad();
}
};

Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
};

using GradRuleFn = std::function<std::optional<apply_result_t>(
ApplyContext&, CustomBackward::Maker&)>;

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
return !ctx.args[i]->m_grad_info_dict.empty();
}

struct GradRuleFallback : std::exception {};

template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}
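And a sketch of binding such a rule into the registry at static-initialization time, mirroring the Init struct in grad_override.cpp (passthrough_grad_rule is the hypothetical rule from the previous sketch; assumes the autogen op headers are included for FastpathCopy):

// Illustration only: associate the rule with FastpathCopy's typeinfo.
namespace {
struct RegisterPassthroughRule {
    RegisterPassthroughRule() {
        register_grad_rule(FastpathCopy::typeinfo(), passthrough_grad_rule);
    }
} _register_passthrough_rule;
}  // anonymous namespace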

} // namespace mgb::imperative::python

namespace pybind11::detail {


+ 0
- 43
imperative/python/src/grad_info.h View File

@@ -1,43 +0,0 @@
/**
* \file imperative/python/src/grad_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <memory>

#include "./intrusive_list.h"

namespace mgb::imperative::python {

struct GradKey;
struct GradFn;
struct GradSlot;

struct GradSlotPtr {
std::shared_ptr<GradFn> grad_fn;
size_t idx;

operator bool() const;
GradSlot* operator->();
};

struct GradInfo : GradSlotPtr,
intrusive_list::Node<GradInfo, intrusive_list::before_t> {
GradInfo() = default;
GradInfo(GradInfo&) = default;
GradInfo(GradInfo&&) = default;
GradInfo& operator=(GradInfo&) = default;
GradInfo& operator=(GradInfo&&) = default;
GradInfo(const GradInfo& rhs) : GradInfo(const_cast<GradInfo&>(rhs)) {}
GradInfo& operator=(const GradInfo& rhs) {
return *this = const_cast<GradInfo&>(rhs);
}
};

} // namespace mgb::imperative::python

+ 237
- 164
imperative/python/src/grad_override.cpp View File

@@ -11,261 +11,334 @@

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformations/grad.h"

namespace mgb::imperative::python {

class CustomGradMaker {
bool output_size_set = false, input_has_grad_initialized = false;
CustomBackward& target;
size_t nr_inputs;
void init_input_has_grad() {
if (!input_has_grad_initialized) {
input_has_grad_initialized = true;
target.m_input_has_grad.resize(nr_inputs, true);
}
}

public:
CustomGradMaker(CustomBackward& target, size_t nr_inputs)
: target(target), nr_inputs(nr_inputs) {}

CustomGradMaker& backward(CustomBackward::BackwardFn f) {
mgb_assert(!target.m_backward);
target.m_backward = f;
return *this;
}
// mandatory
CustomGradMaker& output_size(size_t sz) {
mgb_assert(!output_size_set);
output_size_set = true;
target.m_output_attrs.resize(sz);
return *this;
}
// optional, defaults to all true
CustomGradMaker& input_has_grad(size_t i, bool v) {
init_input_has_grad();
target.m_input_has_grad.at(i) = v;
return *this;
}
// optional, defaults to all true
CustomGradMaker& output_requires_grad(size_t i, bool v) {
target.m_output_attrs.at(i).requires_grad = v;
return *this;
}
// optional, defaults to all true
CustomGradMaker& output_captured(size_t i, bool v) {
target.m_output_attrs.at(i).captured = v;
return *this;
}
void finalize() {
mgb_assert(output_size_set);
init_input_has_grad();
}
};

namespace {

std::shared_ptr<Tensor> get_shape(Tensor* x) {
ValueRef get_shape(ValueRef x) {
static auto op = GetVarShape::make();
return python::apply(op, x)[0];
return imperative::apply(*op, x)[0];
}

std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
ValueRef reduce_to(ValueRef x, ValueRef s) {
static auto op = Reduce::make();
return python::apply(op, x, s)[0];
return imperative::apply(*op, x, s)[0];
}

std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
ValueRef reshape_to(ValueRef x, ValueRef s) {
static auto op = Reshape::make();
return python::apply(op, x, s)[0];
return imperative::apply(*op, x, s)[0];
}

std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
ValueRef broadcast_to(ValueRef x, ValueRef s) {
static auto op = Broadcast::make();
return python::apply(op, x, s)[0];
return imperative::apply(*op, x, s)[0];
}

std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
HostTensorND scalar{cn, {{1}, dtype}};
std::memset(scalar.raw_ptr(), 0, dtype.size());
interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
auto&& t = std::make_shared<Tensor>(handle);
auto res = broadcast_to(t.get(), shape);
ValueRef make_empty_tensor(
CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
HostTensorStorage storage(*device);
storage.ensure_size(dtype->size());
std::memset(storage.ptr(), 0, dtype->size());
auto t = imperative::apply(
CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
HostStorage::make(storage))[0];
auto res = broadcast_to(t, shape);
return res;
}

std::optional<apply_result_t> elemwise_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto& op = ctx.op->cast_final_safe<Elemwise>();
if (op.mode == Elemwise::Mode::ADD) {
mgb_assert(ctx.nargs == 2);
std::array<std::shared_ptr<Tensor>, 2> input_shapes;
std::optional<std::vector<ValueRef>> elemwise_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto& elemwise = op.cast_final_safe<Elemwise>();
if (elemwise.mode != Elemwise::Mode::ADD) {
return {};
}
mgb_assert(inputs.size() == 2);
std::array<ValueRef, 2> input_shapes;
for (size_t i = 0; i < 2; ++i) {
if (inputs_require_grad[i]) {
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
if (input_requires_grad(ctx, i)) {
input_shapes[i] = get_shape(ctx.args[i]);
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i]);
}
}
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i].get());
}
}
return ret;
});
return apply(ctx);
}
return {};
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<apply_result_t> reshape_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
mgb_assert(ctx.nargs == 2);
std::array<std::shared_ptr<Tensor>, 2> input_shapes;
std::optional<std::vector<ValueRef>> reshape_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 2);
std::array<ValueRef, 2> input_shapes;
for (size_t i = 0; i < 2; ++i) {
if (input_requires_grad(ctx, i)) {
input_shapes[i] = get_shape(ctx.args[i]);
if (inputs_require_grad[i]) {
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) {
ret[i] = reshape_to(grad, shapes[i].get());
ret[i] = reshape_to(grad, shapes[i]);
}
}
return ret;
});
return apply(ctx);
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<apply_result_t> subtensor_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<Subtensor>();
auto&& grad_op = SetSubtensor::make(op.items);
SmallVector<std::shared_ptr<Tensor>> inputs;
if (input_requires_grad(ctx, 0)) {
inputs.push_back(get_shape(ctx.args[0]));
for (size_t i = 1; i < ctx.nargs; ++i) {
inputs.push_back(ctx.args[i]->copy());
std::optional<std::vector<ValueRef>> subtensor_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& subtensor = op.cast_final_safe<Subtensor>();
auto&& grad_op = SetSubtensor::make(subtensor.items);
SmallVector<ValueRef> inputs2;
if (inputs_require_grad[0]) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(
grad->comp_node(), inputs[0].get(), grad->dtype());
args_[0] = zeros.get();
SmallVector<ValueRef> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = grad;
for (size_t i = 1; i < inputs.size(); ++i) {
args_[i + 1] = inputs[i].get();
args_[i + 1] = inputs[i];
}
ret[0] = python::apply(grad_op_, args_)[0];
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
}
return ret;
});
return apply(ctx);
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<apply_result_t> indexingMultiAxisVec_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
SmallVector<std::shared_ptr<Tensor>> inputs;
if (input_requires_grad(ctx, 0)) {
inputs.push_back(get_shape(ctx.args[0]));
for (size_t i = 1; i < ctx.nargs; ++i) {
inputs.push_back(ctx.args[i]->copy());
std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
SmallVector<ValueRef> inputs2;
if (inputs_require_grad[0]) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(
grad->comp_node(), inputs[0].get(), grad->dtype());
args_[0] = zeros.get();
SmallVector<ValueRef> args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = grad;
for (size_t i = 1; i < inputs.size(); ++i) {
args_[i + 1] = inputs[i].get();
args_[i + 1] = inputs[i];
}
ret[0] = python::apply(grad_op_, args_)[0];
ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
}
return ret;
});
return apply(ctx);
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<apply_result_t> reduce_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto& op = ctx.op->cast_final_safe<Reduce>();
if (op.mode == Reduce::Mode::SUM) {
if (ctx.nargs != 1) {
return {};
}
std::array<std::shared_ptr<Tensor>, 1> input_shapes;
if (input_requires_grad(ctx, 0)) {
input_shapes[0] = get_shape(ctx.args[0]);
}
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0].get());
}
return ret;
});
return apply(ctx);
std::optional<std::vector<ValueRef>> reduce_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto& reduce = op.cast_final_safe<Reduce>();
if (reduce.mode != Reduce::Mode::SUM) {
return {};
}
if (inputs.size() != 1) {
return {};
}
std::array<ValueRef, 1> input_shapes;
if (inputs_require_grad[0]) {
input_shapes[0] = get_shape(inputs[0]);
}
return {};
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0]);
}
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<apply_result_t> addAxis_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<AddAxis>();
mgb_assert(ctx.nargs == 1);
bool flag = input_requires_grad(ctx, 0);
auto&& grad_op = RemoveAxis::make(op.axis);
std::optional<std::vector<ValueRef>> addAxis_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& addAxis = op.cast_final_safe<AddAxis>();
mgb_assert(inputs.size() == 1);
bool flag = inputs_require_grad[0];
auto&& grad_op = RemoveAxis::make(addAxis.axis);
std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && flag_) {
ret[0] = python::apply(grad_op_, grad)[0];
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
return ret;
});
return apply(ctx);
maker.finalize();
return imperative::apply(op, inputs);
}

std::optional<apply_result_t> removeAxis_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
mgb_assert(ctx.nargs == 1);
bool flag = input_requires_grad(ctx, 0);
auto&& grad_op = AddAxis::make(op.axis);
std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
mgb_assert(inputs.size() == 1);
bool flag = inputs_require_grad[0];
auto&& grad_op = AddAxis::make(removeAxis.axis);
std::sort(grad_op->axis.begin(), grad_op->axis.end());
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](
BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad && flag_) {
ret[0] = python::apply(grad_op_, grad)[0];
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
return ret;
});
return apply(ctx);
maker.finalize();
return imperative::apply(op, inputs);
}

std::optional<apply_result_t> fastpathcopy_grad_rule(
ApplyContext& ctx, CustomBackward::Maker& maker) {
mgb_assert(ctx.nargs == 1);
std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 1);
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
maker.backward([](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
if (grad) {
ret[0] = grad->shared_from_this();
ret[0] = grad;
}
return ret;
});
return apply(ctx);
maker.finalize();
return imperative::apply(op, inputs);
}

struct Init {
Init() {
auto& reg = grad_rule_registry();
reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
CustomBackward::register_grad_rule(
IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
CustomBackward::register_grad_rule(
RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
}
} _;



+ 0
- 245
imperative/python/src/intrusive_list.h View File

@@ -1,245 +0,0 @@
/**
* \file imperative/python/src/intrusive_list.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/utils/metahelper.h"

namespace mgb::imperative::python::intrusive_list {

// copy policy
struct after_t {};
struct before_t {};
struct disable_t {};

template <typename T>
struct Tail;

// invariant: next->prev == this
template <typename T>
struct Head {
Tail<T>* next;

Head(Tail<T>* node = nullptr) : next(node) {}
Head(const Head<T>&) = delete;
Head<T>& operator=(const Head<T>&) = delete;
Head(Head<T>&& rhs) : next(rhs.next) {
rhs.next = nullptr;
if (next) {
next->prev = this;
}
}
Head<T>& operator=(Head<T>&& rhs) {
mgb_assert(!next);
next = rhs.next;
rhs.next = nullptr;
if (next) {
next->prev = this;
}
return *this;
}
~Head() {
if (next) {
next->prev = nullptr;
}
}
};

// invariant: prev->next == this
template <typename T>
struct Tail {
Head<T>* prev;

Tail(Head<T>* node = nullptr) : prev(node) {}
Tail(const Tail<T>&) = delete;
Tail<T>& operator=(const Tail<T>&) = delete;
Tail(Tail<T>&& rhs) : prev(rhs.prev) {
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
}
Tail<T>& operator=(Tail<T>&& rhs) {
mgb_assert(!prev);
prev = rhs.prev;
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
return *this;
}
~Tail() {
if (prev) {
prev->next = nullptr;
}
}
};

template <typename T, typename policy>
struct Node;

template <typename T>
class Iterator {
T* ptr;

void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); }
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); }

public:
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}

template <typename policy>
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}

T& operator*() { return *static_cast<T*>(ptr); }
T* operator->() { return static_cast<T*>(ptr); }

operator bool() { return ptr; }
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; }

Iterator& operator++() {
inc();
return *this;
}
Iterator& operator--() {
dec();
return *this;
}
Iterator operator++(int) {
auto ret = *this;
inc();
return ret;
}
Iterator operator--(int) {
auto ret = *this;
dec();
return ret;
}
};

// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
template <typename T = void, typename policy = disable_t>
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
private:
using this_t = Node<T, policy>;
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;

public:
using head_t = Head<U>;
using tail_t = Tail<U>;
using head_t::next;
using tail_t::prev;

Node() = default;
Node(const this_t&) = delete;
this_t& operator=(const this_t&) = delete;

//! constructed node is inserted after the input node
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
node.next = this;
if (next) {
next->prev = this;
}
}

//! constructed node is inserted before the input node
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
node.prev = this;
if (prev) {
prev->next = this;
}
}

Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
}

Node& operator=(this_t&& rhs) {
unlink();
prev = rhs.prev;
next = rhs.next;
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
return *this;
}

template <
typename p = policy,
typename = std::enable_if_t<
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
Node(this_t& rhs) : Node(policy{}, rhs) {}

template <
typename p = policy,
typename = std::enable_if_t<
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
this_t& operator=(this_t& rhs) {
insert(policy{}, rhs);
return *this;
}

void unlink() {
if (prev) {
prev->next = next;
}
if (next) {
next->prev = prev;
}
prev = nullptr;
next = nullptr;
}

//! this node is unlinked from its list and inserted after the input node
void insert(after_t, head_t& node) {
unlink();
prev = &node;
next = node.next;
node.next = this;
if (next) {
next->prev = this;
}
}

//! this node is unlinked from its list and inserted before the input node
void insert(before_t, tail_t& node) {
unlink();
next = &node;
prev = node.prev;
node.prev = this;
if (prev) {
prev->next = this;
}
}

void insert_before(tail_t& node) { insert(before_t{}, node); }
void insert_after(head_t& node) { insert(after_t{}, node); }

~Node() { unlink(); }
};

} // namespace mgb::imperative::python::intrusive_list
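For context on what this deletion removes, a minimal usage sketch (Item and demo are illustrative): nodes splice themselves in and out of a list, and destruction unlinks automatically, which is the behavior GradInfo relied on:

// Illustration of the removed API: no container owns the nodes.
using namespace mgb::imperative::python;

struct Item : intrusive_list::Node<Item> {};

void demo() {
    Item a, b, c;
    b.insert_after(a);   // list: a -> b
    c.insert_before(b);  // list: a -> c -> b
    c.unlink();          // list: a -> b; c stands alone again
}                        // remaining nodes unlink in their destructors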

+ 0
- 42
imperative/python/src/module_trace.cpp View File

@@ -1,42 +0,0 @@
/**
* \file imperative/python/src/module_trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./module_trace.h"
#include "./helper.h" // include op pybind11 caster

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_module_trace(ApplyContext& ctx) {
apply_result_t outputs;

auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr);
if (!pyout)
throw py::error_already_set();
auto ret = py::reinterpret_steal<py::object>(pyout);

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::try_cast(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
}

} // namespace mgb::imperative::python

+ 41
- 1
imperative/python/src/module_trace.h View File

@@ -11,10 +11,50 @@

#pragma once

#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"

#include "./tensor.h"

namespace mgb::imperative::python {

apply_result_t apply_module_trace(ApplyContext& ctx);
namespace py = pybind11;

class ModuleTraceTransformation final : public Transformation {
private:
py::function m_hook_fn;
int m_enabled = 0;

std::vector<ValueRef> apply_module_trace_hook(
const OpDef& op, Span<ValueRef> input_values) {
py::list input_tws;
for (auto&& input_value : input_values) {
input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
}
py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
std::vector<ValueRef> outputs;
for (auto&& output_tw : output_tws) {
outputs.push_back(
TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data());
}
return outputs;
}

public:
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
std::vector<ValueRef> apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (op.is<ApplyOp>() && m_enabled > 0) {
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
return outputs;
} else {
return imperative::apply(op, inputs);
}
}

ValueRef unwrap(ValueRef value) override { return value; }

std::string name() const override { return "ModuleTraceTransformation"; }
};

} // namespace mgb::imperative::python
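A rough sketch of how the binding layer might install this transformation (register_transformation is an assumed entry point; the real wiring lives in tensor.cpp, and only ModuleTraceTransformation itself is defined above):

// Hypothetical glue: wrap the Python hook and push the transformation onto
// the dispatch stack, so ApplyOp operations reach apply_module_trace_hook.
// Note: m_enabled starts at 0; an enable step is assumed elsewhere.
inline std::shared_ptr<ModuleTraceTransformation> install_module_trace(
        py::function hook_fn) {
    auto trans = std::make_shared<ModuleTraceTransformation>(hook_fn);
    register_transformation(trans);  // assumed registration entry point
    return trans;
}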

+ 2
- 7
imperative/python/src/ops.cpp View File

@@ -185,7 +185,8 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
}

PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
{const_cast<char*>("scope"), py_get_scope, py_set_scope,
const_cast<char*>("scope"), NULL},
{NULL}};

Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
@@ -556,12 +557,6 @@ void init_ops(py::module m) {
m.def(
"delete_rng_handle",
[](size_t handle) {
// An RNG op might execute after its handle is released due to async
// dispatch, so we need to sync before deleting a handle to avoid a
// memory leak or use-after-free
if (python::interpreter_for_py->check_available()) {
python::interpreter_for_py->sync();
}
mgb::CompNode::sync_all();
py_task_q.wait_all_task_finish();
rng::delete_handle(handle);


+ 493
- 541
imperative/python/src/tensor.cpp
File diff suppressed because it is too large
View File


+ 38
- 216
imperative/python/src/tensor.h View File

@@ -20,6 +20,8 @@
#include "pybind11/pybind11.h"

#include "./pyext17.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/utils/span.h"

namespace mgb::imperative::python {

@@ -32,126 +34,67 @@ struct ObjectPtr : B {

} // namespace mgb::imperative::python

#include "./grad_info.h" // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo

namespace mgb::imperative::python {

struct GradKey;

extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type;

class SharedHandle {
using Handle = interpreter::Interpreter::Handle;
static_assert(std::is_pointer_v<Handle>);
std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
inline explicit SharedHandle(Handle handle)
: holder(handle, [](auto* h) {
if (h) {
interpreter_for_py->del(h);
}
}) {}
SharedHandle(const SharedHandle&) = default;
SharedHandle& operator=(const SharedHandle&) = default;
SharedHandle(SharedHandle&&) = default;
SharedHandle& operator=(SharedHandle&&) = default;

inline Handle get() { return holder.get(); }
};

// impl in grad.cpp
class GradInfoCollection {
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
private:
SmallVector<GradInfo> m_storage;

protected:
void _shrink();
std::string m_name;
ValueRef m_data;

public:
bool contains(GradKey* key);
GradInfo& operator[](GradKey* key);
GradInfo& at(GradKey* key);
bool empty() {
_shrink();
return m_storage.empty();
}
auto begin() {
_shrink();
return m_storage.begin();
}
auto end() {
_shrink();
return m_storage.end();
}
size_t count(GradKey* key) { return contains(key) ? 1 : 0; }
};

struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using flags_t = uint64_t;

struct Flags {
static constexpr flags_t SCALAR = 1;
static constexpr flags_t GRAD = 1 << 1;
static constexpr flags_t TRACE = 1 << 2;
static constexpr flags_t MODULE_TRACE = 1 << 3;
};

flags_t m_flags = 0;

GradInfoCollection m_grad_info_dict;
TraceInfo m_trace_info;
SharedHandle m_handle;
std::string user_custom_name;
std::string automatic_name;
cg::VarNode* m_var;
pybind11::object m_module_trace_info;

using Handle = interpreter::Interpreter::Handle;

inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
inline explicit Tensor(SharedHandle handle)
: m_handle(std::move(handle)), m_var(nullptr) {}
inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {}
inline explicit Tensor(ValueRef data) : m_data{data} {}

~Tensor() = default;

inline std::shared_ptr<Tensor> copy() {
auto ret = std::make_shared<Tensor>(m_handle);
ret->m_flags = m_flags;
ret->m_grad_info_dict = m_grad_info_dict;
ret->m_trace_info = m_trace_info;
ret->m_var = m_var;
auto ret = std::make_shared<Tensor>(m_data.unwrap());
ret->m_name = m_name;
return ret;
}

inline DType dtype() {
if (m_var) {
return m_var->dtype();
inline DType dtype() { return *data().dtype(); }
inline CompNode comp_node() { return *data().device(); }
inline std::optional<ValueShape> shape() {
auto shape = data().shape();
if (!shape) {
return {};
}
return interpreter_for_py->get_dtype(m_handle.get());
return *shape;
}
inline CompNode comp_node() {
if (m_var) {
return m_var->comp_node();
inline HostValue::ref_t numpy() { return data().numpy(); }
inline void reset(ValueRef value) {
m_data = value;
if (!m_name.empty()) {
set_name(m_name);
}
return interpreter_for_py->get_device(m_handle.get());
}
inline TensorShape shape() {
if (m_var) {
return m_var->shape();
inline ValueRef data() { return m_data.unwrap(); }
bool is_scalar() { return data().is_scalar(); }
inline std::string name() { return m_name; }
inline void set_name(std::string name) {
m_name = name;
if (!name.empty()) {
auto output = imperative::apply(RenameValue(name), m_data)[0];
m_data = output;
}
return interpreter_for_py->get_shape(m_handle.get());
}
};

struct TensorWrapper {
public:
std::shared_ptr<Tensor> m_tensor;

inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
: m_tensor(std::move(tensor)) {}
: m_tensor(std::move(tensor)) {
mgb_assert(tensor, "empty storage");
}

inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {}
TensorWrapper(PyObject* args, PyObject* kwargs);
~TensorWrapper() = default;

@@ -191,33 +134,17 @@ struct TensorWrapper {
void reset(PyObject*);
PyObject* detach();
PyObject* isscalar();
void setscalar();
void unsetscalar();
PyObject* _dev_tensor();
void _drop();
PyObject* varnode();
void reset_varnode();
PyObject* handle();
void set_handle(PyObject*);

PyObject* mixin_handle();
PyObject* recording();
PyObject* copied();

void set_mixin_handle(PyObject*);
void set_recording(PyObject*);

PyObject* compiled_info();
void set_compiled_info(PyObject*);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject*);
PyObject* module_trace_info();
void set_module_trace_info(PyObject*);
PyObject* user_custom_name();
void set_user_custom_name(PyObject*);
PyObject* automatic_name();
void set_automatic_name(PyObject*);
void _set_name(PyObject*);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
PyObject* _detail();
void _watch();
};

struct PySymbolVar {
@@ -230,113 +157,8 @@ struct PySymbolVar {
PyObject* py_apply(
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);

struct ApplyContext {
static Tensor::flags_t global_disable;
static Tensor::flags_t global_enable;

Tensor::flags_t flags = 0;
std::shared_ptr<OpDef> op;
Tensor* const* args;
size_t nargs;
PyTypeObject* pytype = nullptr;
bool backward = false;

class scoped_disable : NonCopyableObj {
Tensor::flags_t saved_flags;

public:
scoped_disable(Tensor::flags_t flags)
: saved_flags(ApplyContext::global_disable) {
ApplyContext::global_disable |= flags;
}
~scoped_disable() { ApplyContext::global_disable = saved_flags; }
};
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

template <typename T>
decltype(auto) resolve_arrow(T&& p) {
if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
auto* ret = p;
return ret;
} else {
auto probe = [](auto&& p) -> decltype(p.operator->()) {};
if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
return resolve_arrow(p.operator->());
} else {
return std::forward<T>(p);
}
}
}

template <typename... Args>
constexpr bool is_all_tensor_ptr =
(... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
ApplyContext ctx;
Tensor* arg_arr[] = {resolve_arrow(args)...};
ctx.flags = (0 | ... | args->m_flags);
ctx.args = arg_arr;
ctx.nargs = sizeof...(args);
ctx.op = std::move(op);
return apply(ctx);
}

inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
}
return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t<
std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
}
return apply(op, args, nargs);
}

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value);

inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) {
SmallVector<std::shared_ptr<Tensor>> inputs;
for (size_t i = 0; i < nargs; ++i) {
inputs.push_back(args[i]->shared_from_this());
}
auto apply_functor = [](std::shared_ptr<OpDef> op,
SmallVector<std::shared_ptr<Tensor>> inputs,
size_t) { return apply(op, std::move(inputs)); };
return graph.apply(inputs, apply_functor, &make_const);
}

template <typename T>
auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t<
std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
}
return apply(graph, args, nargs);
}

void init_tensor(pybind11::module);

extern PyObject* cpp_apply_with_tracing;
extern PyObject* cpp_apply_backward_varnode;
extern PyObject* cpp_apply_module_trace;

} // namespace mgb::imperative::python


+ 0
- 63
imperative/python/src/trace.cpp View File

@@ -1,63 +0,0 @@
/**
* \file imperative/python/src/trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./trace.h"
#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;

if (ctx.backward) {
// reach here when compiled=True
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = py::cast(ctx.args[i]->m_var);
}
py::object pyout = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
if (!pyout)
throw py::error_already_set();

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(pyout);
for (size_t i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode*>();
outputs.emplace_back(std::make_shared<Tensor>(pitem));
}
return outputs;
}

auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr);
if (!pyout)
throw py::error_already_set();

// assumption: python function always returns PyList
auto tup = py::reinterpret_steal<py::list>(pyout);
for (size_t i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::try_cast(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
}

} // namespace mgb::imperative::python

+ 0
- 28
imperative/python/src/trace.h View File

@@ -1,28 +0,0 @@
/**
* \file imperative/python/src/trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <stdexcept>
#include "./tensor.h"

namespace mgb::imperative::python {

class TraceReadError : public std::exception {
public:
explicit TraceReadError(const char* m) : message{m} {}
const char* what() const noexcept override { return message.c_str(); }

private:
std::string message = "";
};

apply_result_t apply_trace(ApplyContext& ctx);

} // namespace mgb::imperative::python

+ 0
- 49
imperative/python/src/trace_info.h View File

@@ -1,49 +0,0 @@
/**
* \file imperative/python/src/trace_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "Python.h"
#include "inttypes.h"

namespace mgb::imperative::python {

struct TraceInfo {
int64_t mixin_handle = -1;
bool recording = false;

// refers to CompiledTensorProxy in tracing.py; takes effect from the second trace step onward
PyObject* compiled_info = nullptr;
// refers to TensorInfo in tracing.py; only used during the first trace step
PyObject* trace_mixin_info = nullptr;

TraceInfo() = default;

TraceInfo& operator=(const TraceInfo& that) {
mixin_handle = that.mixin_handle;
recording = that.recording;

trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info);
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);

return *this;
}

~TraceInfo() {
Py_XDECREF(trace_mixin_info);
Py_XDECREF(compiled_info);
}

private:
TraceInfo(const TraceInfo& that) = default;
};

} // namespace mgb::imperative::python

+ 0
- 14
imperative/python/test/conftest.py View File

@@ -16,10 +16,6 @@ import megengine.module
from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_device_count
from megengine.experimental.autograd import (
disable_higher_order_directive,
enable_higher_order_directive,
)
from megengine.jit import trace as _trace
from megengine.module import Linear, Module

@@ -45,13 +41,3 @@ def skip_distributed(request):
platform.system()
)
)


@pytest.fixture(autouse=True)
def resolve_require_higher_order_directive(request):
marker = request.node.get_closest_marker("require_higher_order_directive")
if marker:
enable_higher_order_directive()
yield
if marker:
disable_higher_order_directive()

+ 1
- 1
imperative/python/test/integration/test_trace_dump.py View File

@@ -146,5 +146,5 @@ def test_dump_bn_train_mode():

data = mge.tensor(np.random.random((10, 10, 10, 10)))
bn_train(data)
with pytest.raises(AssertionError):
with pytest.raises(RuntimeError):
bn_train.dump("test.mge")

+ 43
- 8
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -17,7 +17,7 @@ import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.autodiff import Function, GradManager
from megengine.jit import trace


@@ -214,7 +214,7 @@ def test_remote_grad(trace_mode):
x = dist.functional.remote_recv(rank - 1)
y = m(x)
if rank != size - 1:
dist.functional.remote_send(y, dest_rank=rank + 1)
x = dist.functional.remote_send(y, dest_rank=rank + 1)
gm.backward()
else:
y = y.mean()
@@ -224,7 +224,7 @@ def test_remote_grad(trace_mode):
if trace_mode is not None:
train_func = trace(symbolic=trace_mode)(train_func)

for i in range(3):
for i in range(1):
train_func(x)

worker()
@@ -340,7 +340,6 @@ def test_broadcast_grad(trace_mode):
worker()


@pytest.mark.require_higher_order_directive()
def test_2nd_grad_with_manager():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
@@ -359,7 +358,6 @@ def test_2nd_grad_with_manager():
)


@pytest.mark.require_higher_order_directive()
def test_grad_manager_group():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
@@ -376,7 +374,6 @@ def test_grad_manager_group():
x.grad = None


@pytest.mark.require_higher_order_directive()
def test_grad_manager_group_visibility():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
@@ -392,7 +389,6 @@ def test_grad_manager_group_visibility():
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)


@pytest.mark.require_higher_order_directive()
def test_grad_manager_visibility_by_order():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
@@ -410,7 +406,6 @@ def test_grad_manager_visibility_by_order():
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)


@pytest.mark.require_higher_order_directive()
@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1])
def test_emulate_forward_mode_with_reverse_mode(target):
def jvp(inp, expr):
@@ -434,3 +429,43 @@ def test_emulate_forward_mode_with_reverse_mode(target):

np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5)
np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3)


def test_2nd_grad_with_custom_gradient():
class MySin(Function):
def forward(self, x):
self.inp = x
x = mge.Tensor(x.numpy())
y = F.sin(x)
return y

def backward(self, dy):
dx = F.cos(self.inp) * dy
return dx

class MyCos(Function):
def forward(self, x):
self.inp = x
x = mge.Tensor(x.numpy())
y = F.cos(x)
return y

def backward(self, dy):
dx = -MySin()(self.inp) * dy
return dx

x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)

gm = GradManager().attach([x])
gm2 = GradManager().attach([x])

with gm:
with gm2:
y = MyCos()(x)
gm2.backward(y)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
gm.backward(x.grad)
np.testing.assert_almost_equal(
x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
)

+ 162
- 161
imperative/python/test/unit/core/test_autodiff.py View File

@@ -7,8 +7,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import platform
import weakref

import numpy as np
import pytest
@@ -60,24 +58,20 @@ def test_dist_grad():
def worker():
rank = dist.get_rank()
if rank == 0:
grad = Grad()

x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
remote_send(x, 1)
recv_x = remote_recv(1)
y = recv_x * recv_x

grad([y], [as_tensor(np.ones_like(x_np))])
with Grad() as grad:
x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
remote_send(x, 1)
recv_x = remote_recv(1)
y = recv_x * recv_x
grad([y], [as_tensor(np.ones_like(x_np))])
np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
elif rank == 1:
grad = Grad()

recv_x = remote_recv(0)
remote_send(recv_x, 0)

grad([], [])
with Grad() as grad:
recv_x = remote_recv(0)
remote_send(recv_x, 0)
grad([], [])

worker()

@@ -86,11 +80,11 @@ def test_grad():
x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))

y = cos(x)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = cos(x)
grad(y, as_tensor(np.ones_like(x_np)))

grad(y, as_tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np))


@@ -98,12 +92,12 @@ def test_grad_2():
x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = mul(x, x)
y = mul(y, y)
grad(y, as_tensor(np.ones_like(x_np)))

y = mul(x, x)
y = mul(y, y)

grad(y, as_tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


@@ -113,32 +107,31 @@ def test_2nd_grad():
x = as_tensor(x_np)
ones = as_tensor(np.ones_like(x_np))

grad = Grad().wrt(x, callback=save_to(x))
grad._priority = -1
grad2 = Grad().wrt(x, callback=save_to(x))
grad2._priority = 0

y = cos(x)
with Grad("grad2") as grad2:
with Grad("grad") as grad:
grad2.wrt(x, callback=save_to(x))
grad.wrt(x, callback=save_to(x))
y = cos(x)
grad(y, ones)
z = x.grad
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)

grad(y, ones)
z = x.grad
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
x.grad = None
grad2(z, ones)

x.grad = None
grad2(z, ones)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)


def test_grad_with_tensor_wrapper():
x_np = np.random.rand(10).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = mul(x, x)
y = mul(y, y)
grad(y, mge.Tensor(np.ones_like(x_np)))

y = mul(x, x)
y = mul(y, y)

grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


@@ -162,18 +155,21 @@ def test_release():

@check
def _():
g = Grad().wrt(x)
y = x * x
g(y, dy)
with Grad() as g:
g.wrt(x)
y = x * x
g(y, dy)

@check
def _():
with Grad().wrt(x):
with Grad() as g:
g.wrt(x)
pass

@check
def _():
with Grad().wrt(x):
with Grad() as g:
g.wrt(x)
y = x * x


@@ -181,12 +177,12 @@ def test_grad_inplace():
x_np = np.random.rand(10).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = mul(x, x)
y *= y
grad(y, mge.Tensor(np.ones_like(x_np)))

y = mul(x, x)
y *= y

grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


@@ -196,11 +192,11 @@ def test_identity():
dy_np = np.random.rand(*x.shape).astype("float32")
dy = mge.Tensor(dy_np)

grad = Grad().wrt(x, callback=save_to(x))

(y,) = apply(Identity(), x)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
(y,) = apply(Identity(), x)
grad(y, dy)

grad(y, dy)
np.testing.assert_array_equal(x.grad.numpy(), dy_np)


@@ -220,15 +216,14 @@ def test_elemwise_add():
refs["y"] = TensorWeakRef(y)
return x + y

grad = Grad().wrt(x, callback=save_to(x))

z = f(x, y)
del y
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
z = f(x, y)
del y
for k, r in refs.items():
assert r() is None
grad(z, dz)

for k, r in refs.items():
assert r() is None

grad(z, dz)
np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5)


@@ -245,13 +240,12 @@ def test_elemwise_relu():
refs["x"] = TensorWeakRef(x)
return relu(x)

grad = Grad().wrt(x, callback=save_to(x))
z = f(x)
assert refs["x"]() is None
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
z = f(x)
assert refs["x"]() is None
grad(z, dz)

grad(z, dz)
np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0])


@@ -269,21 +263,21 @@ def test_reshape():
x_np = np.random.rand(2, 5).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}

refs = {}
def f(x):
x = x * 1
y = x.reshape(5, 2)
refs["x"] = TensorWeakRef(x)
return y

def f(x):
x = x * 1
y = x.reshape(5, 2)
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))

y = f(x)
for _, r in refs.items():
assert r() is None

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())


@@ -291,21 +285,21 @@ def test_subtensor():
x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
refs = {}
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}

def f(x):
x = x * 1
y = x[1:-1, :2]
refs["x"] = TensorWeakRef(x)
return y
def f(x):
x = x * 1
y = x[1:-1, :2]
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
)
@@ -315,21 +309,21 @@ def test_IndexingMultiAxisVec():
x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}

refs = {}
def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
refs["x"] = TensorWeakRef(x)
return y

def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))

y = f(x)
for _, r in refs.items():
assert r() is None

grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
)
@@ -339,21 +333,21 @@ def test_AxisAddRemove():
x_np = np.random.rand(1, 5).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
refs = {}
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}

def f(x):
x = x * 1
y = F.squeeze(F.expand_dims(x, 2), 0)
refs["x"] = TensorWeakRef(x)
return y
def f(x):
x = x * 1
y = F.squeeze(F.expand_dims(x, 2), 0)
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy()
)
@@ -363,10 +357,11 @@ def test_Broadcast():
x_np = np.random.rand(3, 3, 1).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
y = F.broadcast_to(x, (3, 3, 10))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = F.broadcast_to(x, (3, 3, 10))
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())


@@ -374,10 +369,11 @@ def test_interpolate_fastpath():
x_np = np.random.rand(3, 3, 32, 32).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear")
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear")
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy())


@@ -385,10 +381,11 @@ def test_Reduce_sum():
x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
y = x.sum(axis=0)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = x.sum(axis=0)
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())


@@ -396,10 +393,11 @@ def test_Reduce_mean():
x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
y = x.mean(axis=0)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
y = x.mean(axis=0)
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())


@@ -407,21 +405,21 @@ def test_addAxis():
x_np = np.random.rand(3, 3).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}

refs = {}

def f(x):
x = x * 1
y = F.expand_dims(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y
def f(x):
x = x * 1
y = F.expand_dims(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())


@@ -429,21 +427,21 @@ def test_removeAxis():
x_np = np.random.rand(3, 3, 1, 1).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}

refs = {}
def f(x):
x = x * 1
y = F.squeeze(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y

def f(x):
x = x * 1
y = F.squeeze(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())


@@ -452,11 +450,14 @@ def test_dot():
x = mge.Tensor(x)
u = F.ones((2,))
v = F.ones((2,))
grad = Grad().wrt(x, callback=save_to(x))

def f(x):
return F.dot(u, F.matmul(x, v))
with Grad() as grad:
grad.wrt(x, callback=save_to(x))

def f(x):
return F.dot(u, F.matmul(x, v))

y = f(x)
grad(y, F.ones_like(y))

y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())

+ 66
- 55
imperative/python/test/unit/functional/test_functional.py View File

@@ -267,25 +267,27 @@ def _gen_roi_inp():

def test_roi_align():
inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))

output_shape = (7, 7)
out_feat = F.vision.roi_align(
inp_feat,
rois,
output_shape=output_shape,
mode="average",
spatial_scale=1.0 / 4,
sample_points=2,
aligned=True,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)
with Grad() as grad:
grad.wrt(inp_feat, callback=_save_to(inp_feat))

output_shape = (7, 7)
out_feat = F.vision.roi_align(
inp_feat,
rois,
output_shape=output_shape,
mode="average",
spatial_scale=1.0 / 4,
sample_points=2,
aligned=True,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)

grad(out_feat, tensor(F.ones_like(out_feat)))

grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


@@ -307,20 +309,23 @@ def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
def test_correlation():
##test case 0 check the grad shape
data1, data2 = _gen_correlation()
grad = Grad().wrt(data1, callback=_save_to(data1))

out_feat = F.vision.correlation(
data1,
data2,
kernel_size=5,
max_displacement=4,
stride1=2,
stride2=2,
pad_size=2,
is_multiply=True,
)
with Grad() as grad:
grad.wrt(data1, callback=_save_to(data1))

out_feat = F.vision.correlation(
data1,
data2,
kernel_size=5,
max_displacement=4,
stride1=2,
stride2=2,
pad_size=2,
is_multiply=True,
)

grad(out_feat, tensor(F.ones_like(out_feat)))

grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)

##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194
@@ -391,32 +396,36 @@ def test_correlation():

def test_roi_pooling():
inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
output_shape = (7, 7)
out_feat = F.vision.roi_pooling(
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)
with Grad() as grad:
grad.wrt(inp_feat, callback=_save_to(inp_feat))
output_shape = (7, 7)
out_feat = F.vision.roi_pooling(
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
)
assert make_shape_tuple(out_feat.shape) == (
rois.shape[0],
inp_feat.shape[1],
*output_shape,
)

grad(out_feat, tensor(F.ones_like(out_feat)))

grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


def test_adaptive_avg_pool2d():
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
oshp = (2, 2)
grad = Grad().wrt(inp, callback=_save_to(inp))
outp = F.adaptive_avg_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
)
with Grad() as grad:
grad.wrt(inp, callback=_save_to(inp))
outp = F.adaptive_avg_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
)

grad(outp, tensor(F.ones_like(outp)))

grad(outp, tensor(F.ones_like(outp)))
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
np.testing.assert_equal(
inp.grad.numpy(),
@@ -439,14 +448,16 @@ def test_adaptive_avg_pool2d():
def test_adaptive_max_pool2d():
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
oshp = (2, 2)
grad = Grad().wrt(inp, callback=_save_to(inp))
outp = F.adaptive_max_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
)
with Grad() as grad:
grad.wrt(inp, callback=_save_to(inp))
outp = F.adaptive_max_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
)

grad(outp, tensor(F.ones_like(outp)))

grad(outp, tensor(F.ones_like(outp)))
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
np.testing.assert_equal(
inp.grad.numpy(),


+1 -1  imperative/python/test/unit/functional/test_tensor.py

@@ -351,7 +351,7 @@ def test_expand_dims_for_scalar():

for axis in [1, -2, (1, 2), (-2, -3)]:
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)
np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)


@pytest.mark.parametrize("is_varnode", [True, False])
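
The only change here is the expected exception type: invalid axes on a scalar input now surface as RuntimeError rather than AssertionError, presumably because the check moved into the new C++ dispatch layer. A hedged sketch of the updated contract (the scalar fixture `xx` is an assumption; the axis list is the one the test iterates):

import pytest

import megengine.functional as F
from megengine import tensor

xx = tensor(1.0)  # assumed: a 0-dim scalar tensor, as in the test fixture
for axis in [1, -2, (1, 2), (-2, -3)]:
    with pytest.raises(RuntimeError):  # was AssertionError before this commit
        F.expand_dims(xx, axis)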


+79 -21  imperative/python/test/unit/jit/test_tracing.py

@@ -9,6 +9,7 @@
import inspect
import io
import itertools
import random
from tempfile import mkstemp

import numpy as np
@@ -25,7 +26,7 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
from megengine.module import Module
from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming
@@ -464,36 +465,92 @@ def test_trace_warp_perspective():
f(x, M)


def test_raise_on_trace():
step_count = 0
catch_count = 0
bad_step = 10
@pytest.mark.parametrize(
"normal_expr, mismatch_expr, reason",
[
("a + b + c", "a + b - c", "operator mismatch"),
("a + b + 1", "a + b + 2", "tensors not equals"),
("((a + b), (b + c))[0]", "a + b", "mismature end"),
("a + b + c", "c + (a + b)", "expect internal node, got external"),
("c + (a + b)", "a + b + c", "expect external node, got internal"),
("a + b + c", "a + b + c + c", "too many instructions"),
("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"),
("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
],
)
def test_trace_mismatch(normal_expr, mismatch_expr, reason):
a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2])

mismatch = False

@trace(symbolic=True)
def fn(a, b, c):
if not mismatch:
result = eval(normal_expr)
else:
result = eval(mismatch_expr)
return result

for i in range(20):
try:
d = fn(a, b, c)
except TraceError as e:
assert mismatch
assert str(e) == "trace error because {}".format(reason)
except:
pytest.fail("unexpected trace error")
else:
assert not mismatch
np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy())
mismatch = random.random() > 0.8

class CatchMe(Exception):
pass

def test_exception_in_trace():
a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2])

@trace
def add_abc(a, b, c):
ps = a + b
result = ps + c
if step_count == bad_step:
raise CatchMe("catch me")
mismatch = False

exc = Exception()

@trace(symbolic=True)
def fn(a, b, c):
result = a + b
if not mismatch:
result += c
else:
raise exc
return result

for i in range(100):
for i in range(20):
try:
d = add_abc(a, b, c)
except CatchMe as e:
catch_count += 1
d = fn(a, b, c)
except TraceError as e:
pytest.fail("unexpected trace error")
except Exception as e:
assert mismatch
assert e is exc
else:
assert not mismatch
np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
step_count += 1
mismatch = random.random() > 0.8

assert catch_count == 1

def test_graph_error():
a = tensor(np.arange(8).reshape((2, 4)))
b = tensor(np.arange(8).reshape((2, 4)))

@trace(symbolic=True)
def fn(a, b):
return a + b

fn(a, b)
with pytest.raises(RuntimeError):
fn(a, b.transpose())
fn(a, b)
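
Taken together, these tracing tests pin down the new failure contract: diverging from a recorded trace raises a typed TraceError carrying a structured reason, user exceptions propagate unchanged, and a shape mismatch during replay is a plain RuntimeError that leaves the trace usable afterwards. A condensed sketch of the operator-mismatch case, distilled from test_trace_mismatch above (a single branch flip stands in for its randomized driver loop):

import numpy as np

from megengine import tensor
from megengine.jit import TraceError, trace

a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
mismatch = False

@trace(symbolic=True)
def fn(a, b):
    # the first call records the program; taking the other branch on a
    # later call diverges from the recording instead of re-tracing
    return a - b if mismatch else a + b

np.testing.assert_equal(fn(a, b).numpy(), (a + b).numpy())  # records the trace
mismatch = True
try:
    fn(a, b)  # replay sees Sub where the recording has Add
except TraceError as e:
    assert str(e) == "trace error because operator mismatch"  # per the test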


@pytest.mark.parametrize("trace_mode", [False, True])
@@ -653,9 +710,10 @@ def test_trace_jit_config():

x = tensor(2)
y = func(x)
func._compile()
y = func(x)
# func._compile()

options = func._graph.options
options = func._trace.options
mapping = {None: 0, False: 1, True: 2}
assert options.graph_opt.jit == 0
assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
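
Beyond the option assertions, this hunk records two side effects of the refactor: there is no explicit _compile() step anymore (a second call finalizes the trace), and the graph options moved from the removed _graph attribute to _trace. A minimal sketch of that inspection pattern, mirroring the test rather than any public API, since _trace is internal:

from megengine import tensor
from megengine.jit import trace

@trace(symbolic=True)
def double(x):
    return x * 2

x = tensor(2)
double(x)  # first call records the trace
double(x)  # second call finalizes it; no explicit _compile() needed

options = double._trace.options    # internal attribute, as used by the test
assert options.graph_opt.jit == 0  # no JIT fusion requested in this sketch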


+16 -12  imperative/python/test/unit/quantization/test_fake_quant.py

@@ -82,9 +82,10 @@ def test_tqt():
x = mge.tensor(x, dtype="float32")
s = mge.tensor(s, dtype="float32")
g_y = mge.tensor(g_y, dtype="float32")
grad = Grad().wrt(x, s, callback=cb)
y = tqt_forward(-127, 127, x, s)
grad(y, g_y)
with Grad() as grad:
grad.wrt(x, s, callback=cb)
y = tqt_forward(-127, 127, x, s)
grad(y, g_y)
g_x, g_s = g

np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5)
@@ -131,14 +132,16 @@ def test_fakequant():

# test backward
x = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qparams)
grad(y, tensor(F.ones_like(x)))
with Grad() as grad:
grad.wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qparams)
grad(y, tensor(F.ones_like(x)))

x1 = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x1, callback=_save_to(x1))
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax)
grad(y1, tensor(F.ones_like(x1)))
with Grad() as grad:
grad.wrt(x1, callback=_save_to(x1))
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax)
grad(y1, tensor(F.ones_like(x1)))

assert np.allclose(x.grad.numpy(), x1.grad.numpy())
assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
@@ -237,9 +240,10 @@ def test_lsq():
grad_s = mge.tensor(grad_s, dtype="float32")

g_y = mge.tensor(g_y, dtype="float32")
grad = Grad().wrt(x, s, callback=cb)
y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
grad(y, g_y)
with Grad() as grad:
grad.wrt(x, s, callback=cb)
y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
grad(y, g_y)
g_x, g_s = g

np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7)


+4 -3  imperative/python/test/unit/random/test_rng.py

@@ -430,9 +430,10 @@ def test_ShuffleRNG():
n, m = 6, 3
arr = np.arange(n * m)
out0 = Tensor(arr, dtype="float32")
grad = Grad().wrt(out0, callback=cb)
random.shuffle(out0)
grad(out0, F.ones_like(out0))
with Grad() as grad:
grad.wrt(out0, callback=cb)
random.shuffle(out0)
grad(out0, F.ones_like(out0))
m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0")

