diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index a754d8ef..5fab6f14 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -28,9 +28,6 @@ class AttachSpec:
     __slots__ = "tensor", "callbacks"
 
 
-_global_priority = 0
-
-
 class GradManager:
     r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
     automatic differentiation (a.k.a. back propagation).
@@ -127,7 +124,6 @@ class GradManager:
         self._grad = None
         self._after_backward_callback = []
         self._gradients = {}
-        self._priority = None
 
     def attached_tensors(self):
         r"""Return attached tensor list from :meth:`attach`."""
@@ -299,31 +295,25 @@ class GradManager:
                         tensor.grad = grad
                     else:
                         tensor.grad += grad
-                    if tensor._isscalar() and tensor.grad is not None:
-                        tensor.grad._setscalar()
         finally:
             self.release()
             backwarding_grad_manager = cache
-        set_option("record_computing_path", 1)
-        pop_scope("backward")
+            set_option("record_computing_path", 1)
+            pop_scope("backward")
 
     def record(self):
         r"""Start recording operations
 
         After this call, you will be able to call :meth:`backward`.
         """
-        global _global_priority
         if self._recording:
             raise RuntimeError("already recording")
         grad = Grad()
         self._recording = True
         self._grad = grad
+        grad.__enter__()
         for spec in self._attach_specs.values():
             self._do_record(spec)
-        if self._priority is None:
-            grad._priority = _global_priority
-            _global_priority -= 1
-        grad.__enter__()
 
     def _do_record(self, spec):
         tensor = spec.tensor()
@@ -331,6 +321,8 @@ class GradManager:
             return
 
         def callback(grad, callbacks=spec.callbacks):
+            from ..functional import ones_like
+
             for cb in callbacks:
                 grad = cb(tensor, grad)
             self._gradients[id(tensor)] = grad
@@ -343,14 +335,11 @@ class GradManager:
 
         After this call, you will not be able to call :meth:`backward`.
         """
-        global _global_priority
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
             self._grad = None
         self._recording = False
         self._gradients = dict()
-        if self._priority is None:
-            _global_priority += 1
 
     def __enter__(self):
         self.record()
@@ -382,15 +371,14 @@ class GradManagerGroup:
     __ror__ = merge_with
 
     def __enter__(self):
-        global _global_priority
-        _global_priority += 1
+        Grad.stack.append([])
+        Grad.begin_group()
         for gm in self._gms:
-            gm._priority = _global_priority
             gm.record()
+            assert gm._grad is not None
+        Grad.end_group()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        global _global_priority
-        _global_priority -= 1
-        for gm in self._gms:
+        for gm in reversed(self._gms):
             gm.release()
-            gm._priority = None
+            assert gm._grad is None
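The `GradManagerGroup` rewrite above drops the global `_global_priority` counter: grouping now goes through `Grad.stack` with `Grad.begin_group()`/`Grad.end_group()`, so every manager entered by the group records as one unit and is released in reverse order. A minimal usage sketch, assuming `GradManager.__or__` (not shown in this patch) still composes managers into a `GradManagerGroup`, so a tensor attached to both managers accumulates gradients from each `backward` call:

    import megengine as mge
    from megengine.autodiff import GradManager

    w = mge.Parameter([2.0])
    gm1 = GradManager().attach([w])  # attach() returns the manager itself
    gm2 = GradManager().attach([w])
    with gm1 | gm2:                  # __enter__: begin_group(), record each gm, end_group()
        y = (w * 3.0).sum()
        gm1.backward(y)              # w.grad == 3.0
        gm2.backward(y)              # accumulates into the same tensor: w.grad == 6.0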
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index 69a602da..5f4c011b 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -6,17 +6,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import functools
-import heapq
-import itertools
-import typing
 import weakref
 
-import numpy as np
-
-from .._imperative_rt import core2, ops
-from ..ops.builtin import Elemwise, OpDef, RemoteSend
-from ..ops.special import Const
+from .._imperative_rt import core2
 
 _grad_count = 0
 _grad_manager_dict = weakref.WeakValueDictionary()
@@ -36,6 +28,10 @@ class GradKey(core2.GradKey):
 
 
 class Grad:
+    stack = []
+    grouping = False
+    key2grad = weakref.WeakValueDictionary()
+
     def __init__(self, name=None):
         global _grad_count
         if name is None:
@@ -43,15 +39,9 @@ class Grad:
             _grad_count += 1
         self._refkeeper = []
         self._impl = GradKey(name)
+        Grad.key2grad[self._impl] = self
         _grad_manager_dict[self._name] = self
-
-    @property
-    def _priority(self):
-        return self._impl.priority
-
-    @_priority.setter
-    def _priority(self, priority):
-        self._impl.priority = priority
+        self._group = [weakref.ref(self)]
 
     @property
     def _name(self):
@@ -70,33 +60,80 @@ class Grad:
         if not isinstance(ys, Sequence):
             ys = [ys]
+
         if not isinstance(dys, Sequence):
             dys = [dys]
 
+        group = [ref() for ref in self._group]
+
+        for grad in group:
+            if grad is self:
+                continue
+            grad.suppress()
+
         self._impl.backward(ys, dys)
 
+        for grad in group:
+            if grad is self:
+                continue
+            grad.resume()
+
         self._refkeeper = None
+        return None
 
     def __enter__(self):
+        ref = weakref.ref(self)
+        self._impl.enter()
+        if Grad.grouping:
+            group = Grad.stack[-1]
+            self._group = group
+            group.append(ref)
+        else:
+            Grad.stack.append(self._group)
         return self
 
     def __exit__(self, _1, _2, _3):
+        self._impl.exit()
         self._refkeeper = None
-        del self._impl
-
-
-class Function(ops.PyOpBase):
+        del Grad.key2grad[self._impl]
+        self._impl = None
+        self._group.remove(weakref.ref(self))
+        if len(self._group) == 0:
+            Grad.stack.remove(self._group)
+
+    @staticmethod
+    def begin_group():
+        assert not Grad.grouping
+        Grad.grouping = True
+
+    @staticmethod
+    def end_group():
+        group = Grad.stack[-1]
+        assert len(group) > 0
+        assert Grad.grouping
+        Grad.grouping = False
+
+    def suppress(self):
+        if self._impl is not None:
+            self._impl.suppress()
+
+    def resume(self):
+        if self._impl is not None:
+            self._impl.resume()
+
+
+class Function:
     r"""Defines a block of operations with customizable differentiation.
-
+
     The computation should be defined in ``forward`` method, with gradient
     computation defined in ``backward`` method.
-
+
     Each instance of ``Function`` should be used only once during forwarding.
-
+
     Examples:
-
+
     .. code-block::
-
+
         class Sigmoid(Function):
             def forward(self, x):
                 y = 1 / (1 + F.exp(-x))
@@ -115,7 +152,7 @@ class Function(ops.PyOpBase):
 
         Returns:
             a tuple of Tensor or a single Tensor.
-
+
         Note:
             * This method should return a tuple of Tensor or a single Tensor representing the output
               of the function.
@@ -128,7 +165,7 @@ class Function(ops.PyOpBase):
 
         Args:
             output_grads: gradients of outputs that are returned by :meth:`forward`.
-
+
         Note:
             * In case when some tensors of outputs are not related to loss function, the corresponding
               values in ``output_grads`` would be ``None``.
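The ``None`` convention documented above is easy to exercise: an output that never reaches the differentiated value receives ``None`` in ``backward``. A sketch with a made-up two-output ``Function`` (``ScaleTwice`` is hypothetical, not part of this patch):

    import megengine as mge
    from megengine.autodiff import GradManager
    from megengine.core.autodiff.grad import Function

    class ScaleTwice(Function):
        def forward(self, x):
            return x * 2, x * 3      # two outputs

        def backward(self, g0, g1):
            grad = 0                 # g0/g1 arrive as None when that output is unused
            if g0 is not None:
                grad = grad + g0 * 2
            if g1 is not None:
                grad = grad + g1 * 3
            return grad              # one input -> one gradient

    x = mge.Parameter([1.0])
    gm = GradManager().attach([x])
    with gm:
        a, b = ScaleTwice()(x)
        gm.backward(a.sum())         # only ``a`` contributes, so g1 is None
    print(x.grad)                    # -> Tensor([2.])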
@@ -148,10 +185,40 @@ class Function(ops.PyOpBase): return self._default_rule(*args), self.backward def __call__(self, *args): - ret = core2.apply(self, *args) + for arg in args: + if not isinstance(arg, core2.Tensor): + raise TypeError( + "op Function expect type Tensor as inputs, got {}".format(type(arg)) + ) + + grad_key = core2.get_grad_key(args) + if grad_key is None: + return self._default_rule(*args) + + grad = Grad.key2grad[grad_key] + group = [ref() for ref in grad._group] + + for grad in group: + grad.suppress() + outputs, backward = self._grad_rule(*args) + for grad in reversed(group): + grad.resume() + + def normalized_backward(*output_grads): + input_grads = backward(*output_grads) + if isinstance(input_grads, core2.Tensor) or input_grads is None: + input_grads = (input_grads,) + return input_grads + if self.__single_output: - (ret,) = ret - return ret + outputs = (outputs,) + for grad in reversed(group): + if grad._impl is None: + continue + outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs) + if self.__single_output: + (outputs,) = outputs + return outputs def __getstate__(self): return self.__dict__ diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 92593475..fe15d2bf 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -26,7 +26,6 @@ from .utils import ( convert_inputs, isscalar, make_shape_tuple, - setscalar, ) _ElwMod = builtin.Elemwise.Mode @@ -34,14 +33,7 @@ _ElwMod = builtin.Elemwise.Mode def _elwise_apply(args, mode): op = builtin.Elemwise(mode) - _isscalar = True - for i in args: - if isscalar(i) == False: - _isscalar = False - break (result,) = apply(op, *args) - if _isscalar: - setscalar(result) return result @@ -203,8 +195,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: op = builtin.RemoveAxis(axis=axis) (result,) = apply(op, inp) - if len(axis) == inp.ndim: - setscalar(result) return result @@ -221,6 +211,7 @@ def _reduce(mode): op = builtin.Reduce(mode=mode, axis=0) (result,) = apply(op, data) + result = _remove_axis(result, 0) elif isinstance(axis, collections.abc.Iterable): axis = _normalize_axis(self.ndim, axis, reverse=True) for ai in axis: @@ -239,8 +230,6 @@ def _reduce(mode): if self.dtype == np.bool_: if mode in ["min", "max"]: result = result.astype("bool") - if axis is None or self.ndim == 1: - setscalar(result) return result return f @@ -457,7 +446,6 @@ class ArrayMethodMixin(abc.ABC): len(args) == 0 ), "transpose for scalar does not accept additional args" ret = self.to(self.device) - setscalar(ret) return ret if not args: args = range(self.ndim)[::-1] diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index b559d18b..f5513639 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -111,7 +111,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): if not isinstance(tuple_val, tuple): tuple_val = (tuple_val,) ndim_indexed = 0 - ndim_indexed_scalar = 0 for i in tuple_val: if not i is Ellipsis: ndim_indexed += ( @@ -119,14 +118,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") else 1 ) - if isscalar(i): - ndim_indexed_scalar += 1 - ret_scalar = False - try: - ret_scalar = ndim_indexed_scalar == inp.ndim - except ValueError: - # inp.ndim is 
unknown - pass else: if ndim_indexed > inp.ndim: raise IndexError( @@ -221,7 +212,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): items.append(item) if new_axes: raise IndexError("newaxis is not allowed here") - return inp, tensors, items, use_subtensor, ret_scalar + return inp, tensors, items, use_subtensor def try_condtake(tensor, index): @@ -247,14 +238,12 @@ def getitem(tensor, index): try_result = try_condtake(tensor, index) if len(try_result) == 2: return try_result[0] - tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) + tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) if use_subtensor: op = builtin.Subtensor(items=items) else: op = builtin.IndexingMultiAxisVec(items=items) (result,) = apply(op, tensor, *tensors) - if ret_scalar: - result._setscalar() return result @@ -266,7 +255,7 @@ def setitem(tensor, index, value): tensor = tensor.reshape(-1) if not isinstance(value, (Tensor, SymbolVar)): (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) - tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) + tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) if use_subtensor: op = builtin.Subtensor(items=items) else: diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 1f7f2907..4c2edc92 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -17,6 +17,7 @@ import numpy as np from .. import _imperative_rt from .._imperative_rt import GraphOptimizeOptions, SerializationFormat +from .._imperative_rt.core2 import apply from .._wrap import as_device from ..ops.builtin import OpDef @@ -126,9 +127,8 @@ class Graph(_imperative_rt.ComputingGraph): class VarNode: - def __init__(self, node: _imperative_rt.VarNode, isscalar=False): + def __init__(self, node: _imperative_rt.VarNode): self._node = node - self._isscalar = isscalar if hasattr(self.graph, "_var_cache"): self.graph._var_cache[node] = self @@ -530,9 +530,6 @@ def _unwrap(x): def apply_normal_varnode(op: OpDef, *args: VarNode): - # for PyOp like RemoteSend/Recv - if getattr(op, "op", None): - op = op.op outputs = _imperative_rt.invoke_op(op, _unwrap(args)) return _wrap(outputs) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 0b2b2104..c9431dc9 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -51,10 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): def astype(x, dtype): dtype = np.dtype(dtype) if not is_dtype_equal(x.dtype, dtype): - isscalar = x._isscalar() (x,) = apply(builtin.TypeCvt(dtype=dtype), x) - if isscalar: - x._setscalar() return x @@ -129,13 +126,6 @@ def isscalar(x): return np.isscalar(x) -def setscalar(x): - if isinstance(x, (Tensor, SymbolVar)): - x._setscalar() - else: - raise NotImplementedError("Unsupport type {}".format(type(x))) - - def astensor1d(x, *reference, dtype=None, device=None): """Convert something to 1D tensor. 
Support following types @@ -237,6 +227,7 @@ for name, mode in [ ("**", "pow"), ("max", "max"), ("additive", "add"), + ("exp", "EXP"), ]: _opr_map[(name, 2)] = builtin.Elemwise(mode=mode) diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index ca4d27cd..51f1f064 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -13,7 +13,7 @@ import numpy as np from ..core._imperative_rt.core2 import apply from ..core.autodiff.grad import Function, _grad_manager_dict from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend -from ..core.tensor.utils import isscalar, setscalar +from ..core.tensor.utils import isscalar from ..device import get_default_device, what_is_xpu from ..tensor import Tensor from . import group @@ -72,15 +72,6 @@ def collective_comm(inp, mode, group, device): ) (result,) = apply(op, inp) # assume all workers have homogeneous shape - if mode in ( - CollectiveComm.Mode.REDUCE_SUM, - CollectiveComm.Mode.BROADCAST, - CollectiveComm.Mode.ALL_REDUCE_SUM, - CollectiveComm.Mode.ALL_REDUCE_MAX, - CollectiveComm.Mode.ALL_REDUCE_MIN, - ): - if isscalar(inp): - setscalar(result) return result @@ -190,8 +181,7 @@ def reduce_sum( # Rank 0 # output: None # Rank 1 # output: Tensor([1]) """ - op = _ReduceSum(group, device) - (out,) = apply(op, inp) + out = _ReduceSum(group, device)(inp) if group.rank == 0: return out @@ -258,8 +248,7 @@ def broadcast( _bcast_tracer_state(group, inp) - op = _Broadcast(group, device) - (out,) = apply(op, inp) + out = _Broadcast(group, device)(inp) return out @@ -604,8 +593,7 @@ def gather( inp.shape ) - op = _Gather(group, device) - (out,) = apply(op, inp) + out = _Gather(group, device)(inp) if group.rank == 0: if axis == 0: @@ -708,8 +696,7 @@ def scatter( + [_ for _ in range(axis + 1, inp.ndim + 1)] ) inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape) - op = _Scatter(group, device) - (out,) = apply(op, inp) + out = _Scatter(group, device)(inp) return out @@ -832,7 +819,7 @@ class _RemoteRecv(Function): self.op = op def forward(self, dummy): - return apply(self.op, dummy) + return apply(self.op, dummy)[0] def backward(self, grad): get_client().bcast_val(grad is not None, self.op.key, 2) @@ -871,7 +858,7 @@ def remote_send(inp: Tensor, dest_rank: int): op.addr, op.port = get_mm_server_addr() op.rank_to = dest_rank op.backend = _backend() - (out,) = apply(_RemoteSend(op), inp) + out = _RemoteSend(op)(inp) _save_output_for_autodiff(inp, out) @@ -912,11 +899,6 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor inp = Tensor(0, device=device) _bcast_tracer_state(group, inp) - _isscalar = False - if len(shape) == 0: - shape = (1,) - _isscalar = True - op = RemoteRecv() op.key = group.key op.cn = device @@ -926,7 +908,5 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor op.rank_from = src_rank op.backend = _backend() - (ret,) = apply(_RemoteRecv(op), inp) - if _isscalar: - setscalar(ret) + ret = _RemoteRecv(op)(inp) return ret diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 9998a63c..b70e2da1 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -67,9 +67,6 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): op.offsets = offsets op.shapes = [s or (1,) for s in shapes] outputs 
= apply(op, inp)
-    for s, x in zip(shapes, outputs):
-        if not s:
-            x._setscalar()
     return outputs
diff --git a/imperative/python/megengine/experimental/autograd.py b/imperative/python/megengine/experimental/autograd.py
deleted file mode 100644
index 8c8b5d25..00000000
--- a/imperative/python/megengine/experimental/autograd.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
-from ..core._imperative_rt.core2 import (
-    set_allow_higher_order_directive as _set_allow_higher_order_directive,
-)
-
-__all__ = [
-    "enable_higher_order_directive",
-    "disable_higher_order_directive",
-]
-
-
-def enable_higher_order_directive():
-    _set_allow_higher_order_directive(True)
-
-
-def disable_higher_order_directive():
-    _set_allow_higher_order_directive(False)
diff --git a/imperative/python/megengine/functional/inplace.py b/imperative/python/megengine/functional/inplace.py
index 05126f50..e6766a8e 100644
--- a/imperative/python/megengine/functional/inplace.py
+++ b/imperative/python/megengine/functional/inplace.py
@@ -12,8 +12,5 @@ from ..core.ops.builtin import InplaceAdd
 
 
 def _inplace_add_(dest, delta, alpha, beta):
-    isscalar = dest._isscalar()
     dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
-    if isscalar:
-        dest._setscalar()
     return dest
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 9ddf15cd..3690f562 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
 from ..core.tensor import amp
-from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
+from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
 from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_kwargs_default
@@ -1149,7 +1149,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
         inp1.ndim <= 1 and inp2.ndim <= 1
     ), "Input tensors for dot must be 1-dimensional or scalar"
     (result,) = apply(op, inp1, inp2)
-    setscalar(result)
     return result
 
 
@@ -1200,5 +1199,4 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
 
     for i in range(len(inps)):
         inps[i]._reset(oups[i])
-    out._setscalar()
     return out
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 025b30ea..a7220993 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -35,7 +35,6 @@ from ..core.tensor.utils import (
     cast_tensors,
     convert_single_value,
     make_shape_tuple,
-    setscalar,
     subgraph,
 )
 from ..device import get_default_device
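With the ``setscalar`` bookkeeping removed in the hunks above, 0-dim results carry their scalar-ness in the runtime tensor itself rather than in a Python-side flag. A quick sketch of the user-visible behavior this series assumes stays intact, using ``functional.dot`` from the ``math.py`` hunk:

    import megengine as mge
    import megengine.functional as F

    a = mge.tensor([1.0, 2.0])
    b = mge.tensor([3.0, 4.0])
    s = F.dot(a, b)   # 0-dim tensor; no explicit setscalar(result) call anymore
    print(s.shape)    # -> ()
    print(s.item())   # -> 11.0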
could not be empty" - if inp._isscalar(): - assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) - if len(axis) == 1: - inp = copy(inp, device=None) - inp._unsetscalar() - return inp - axis = axis[1:] op = builtin.AddAxis(axis=axis) (result,) = apply(op, inp) return result @@ -1164,8 +1157,6 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): if axis is None: inp = inp.reshape(-1) # flatten axis = 0 - if inp._isscalar(): - inp._unsetscalar() shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device) # assume inp.ndim is not changed during trace max_axis = len(shape) - 1 diff --git a/imperative/python/megengine/jit/__init__.py b/imperative/python/megengine/jit/__init__.py index bd50925e..455f3502 100644 --- a/imperative/python/megengine/jit/__init__.py +++ b/imperative/python/megengine/jit/__init__.py @@ -6,19 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..core._imperative_rt.core2 import ( - set_cpp_apply_const_with_tracing, - set_cpp_apply_with_tracing, -) from .dtr_config import DTRConfig from .graph_opt_config import GraphOptimizationConfig from .sublinear_memory_config import SublinearMemoryConfig -from .tracing import ( - apply_const_with_tracing, - apply_with_tracing, - exclude_from_trace, - trace, -) - -set_cpp_apply_with_tracing(apply_with_tracing) -set_cpp_apply_const_with_tracing(apply_const_with_tracing) +from .tracing import TraceError, exclude_from_trace, trace diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 19f51544..a5aa567e 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -15,24 +15,18 @@ import os import pickle import re import struct +import sys from typing import Any import cv2 import numpy as np -from megengine.logger import get_logger - from .. 
import tensor from ..core import _imperative_rt as rt from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata from ..core._imperative_rt.core2 import Tensor as RawTensor -from ..core._imperative_rt.core2 import ( - TensorWeakRef, - apply, - set_tracing, - skip_tracing, - unset_tracing, -) +from ..core._imperative_rt.core2 import Trace, TraceError, name_tensor # skip_tracing, +from ..core._imperative_rt.graph import _set_priority_to_id from ..core._imperative_rt.ops import ( AssertEqual, CollectiveComm, @@ -41,10 +35,8 @@ from ..core._imperative_rt.ops import ( RemoteSend, ) from ..core._trace_option import set_symbolic_shape -from ..core._wrap import as_device -from ..core.ops.builtin import BatchNorm, OpDef from ..core.tensor import megbrain_graph as G -from ..core.tensor.utils import setscalar +from ..logger import get_logger from ..utils import comp_graph_tools as cgtools from ..utils.naming import AutoNaming from ..utils.profiler import is_profiling @@ -59,11 +51,8 @@ def _input_node_use_static_shape(): return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None -class TraceMismatchError(RuntimeError): - pass - - active_trace = None +skip_tracing = False def is_tracing(): @@ -81,59 +70,17 @@ def exclude_from_trace(): return try: skip_tracing = True - unset_tracing() if active_trace is not None: active_trace._begin_excluded_region() yield + if active_trace is not None: + active_trace._end_excluded_region() finally: skip_tracing = False - set_tracing() - - -class TensorInfo: - __slots__ = ( - # collected attributes - "name", - "external", - "data_read", - "shape_read", - "value_read", - "exported", - "device", - "dtype", - "shape", - "is_const", - "bound_data", - "bound_data_numpy", - # resources for execution - "varnode", - "data_setter", - "shape_reader", - "value_reader", - "data_reader", - ) - - def __init__(self): - self.name = None - self.exported = None - self.data_read = None - self.shape_read = None - self.value_read = None - self.bound_data = None - self.bound_data_numpy = None - - self.data_setter = None - self.shape_reader = None - self.value_reader = None - self.data_reader = None - - def get_numpy(self): - if self.bound_data_numpy is None: - self.bound_data_numpy = self.bound_data.numpy() - return self.bound_data_numpy - - -_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv} + + +def array_comparator(lhs, rhs): + return np.all(lhs == rhs) class trace: @@ -174,539 +121,145 @@ class trace: symbolic_shape: bool = True, ): self.__wrapped__ = function - self._symbolic = symbolic or record_only self._capture_as_const = capture_as_const or record_only - self._record_only = record_only - self._sublinear_memory_config = sublinear_memory_config - self._dtr_config = dtr_config - self._profiling = profiling - self._profiler = None - self._profiler2 = None - self._graph_opt_level = opt_level - self._graph_opt_config = graph_opt_config - self._symbolic_shape = symbolic_shape - self._output_handles = set() - - self._reset() - - def _reset(self): - self._untraced = True - self._tinfo = [] # handle -> TensorInfo - self._seq = [] - self._pc = 0 - self._graph = None - self._need_reset_nodes = None - self._lazy_eval_graph = None - self._lazy_eval_tensors = set() - self._lazy_eval_links = None - self._active_tensors = set() - self._tensor_remaps = None - self._inputs_to_restore = None self._arg_bindings = None self._kwarg_bindings = None self._output_bindings = None - self._output_names = None - - def _new_handle(self): - handle = 
len(self._tinfo) - info = TensorInfo() - self._tinfo.append(info) - return handle, info - - def _apply_op(self, op, args): - assert not self._untraced - # check against trace - if self._pc >= len(self._seq): - raise TraceMismatchError("trace should end here, but more op observed") - record = self._seq[self._pc] - op_, ihandles, ohandles = record - if (isinstance(op_, str) and op_ == "Const") or (op != op_): - raise TraceMismatchError("op different from last time") - if len(ihandles) != len(args): - raise TraceMismatchError("op input size different from last time") - - # check all inputs of crrent op - for h, x in zip(ihandles, args): - info = self._tinfo[h] - if info.external: - if ( - x._compiled_info is not None - and not self._tinfo[x._mixin_handle].exported - ): - raise TraceMismatchError( - "failed to capture: input was an external tensor " - "last time, got an internal tensor this time" - ) - if info.bound_data: - if x._compiled_info is not None: - raise TraceMismatchError( - "const capture violated: was an external tensor " - "last time, got an internal tensor this time" - ) - if x._handle != info.bound_data._handle: - if not np.array_equal(x.numpy(), info.bound_data.numpy()): - raise TraceMismatchError( - "const capture violated: got " - "a different tensor this time" - ) - else: - if info.dtype != x.dtype: - raise TraceMismatchError( - "failed to capture: different dtype from last time" - ) - if info.device != x.device: - raise TraceMismatchError( - "failed to capture: different device from last time" - ) - info.data_setter.set_value(x._dev_tensor()) - else: - if x._mixin_handle == -1: - if x._handle not in self._tensor_remaps: - raise TraceMismatchError( - "unexpected capture: trying to use an external tensor as " - "input, but that input was an internal tensor last time" - ) - else: - x._mixin_handle = self._tensor_remaps[ - x._handle - ]._CompiledTensorProxy__handle - if x._mixin_handle != h: - raise TraceMismatchError( - "mis-wiring: input edge to an data flow " - "graph node is different from last time" - ) + self._symbolic_shape = symbolic_shape + self._graph_options = { + "no_force_inplace": True, + "graph_opt_level": opt_level, + "seq_opt.enable_seq_comp_node_opt": False, + } + + # prevent cyclic reference + graph_options = self._graph_options + if dtr_config is not None: + graph_options["enable_dtr_memory_opt"] = True + graph_options[ + "dtr_config.eviction_threshold" + ] = dtr_config.eviction_threshold + graph_options[ + "dtr_config.evictee_minimum_size" + ] = dtr_config.evictee_minimum_size + graph_options[ + "dtr_config.recomp_memory_factor" + ] = dtr_config.recomp_memory_factor + graph_options[ + "dtr_config.recomp_time_factor" + ] = dtr_config.recomp_time_factor + if graph_opt_config is not None: + mapping = {None: 0, False: 1, True: 2} + graph_options["graph_opt.jit_config.fuse_dimshuffle"] = mapping[ + graph_opt_config.jit_fuse_dimshuffle + ] + graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[ + graph_opt_config.jit_fuse_reduce + ] + if sublinear_memory_config is not None: + graph_options["enable_sublinear_memory_opt"] = True + graph_options[ + "sublinear_mem_config.lb_memory_mb" + ] = sublinear_memory_config.lb_memory_mb + graph_options[ + "sublinear_mem_config.genetic_nr_iter" + ] = sublinear_memory_config.genetic_nr_iter + graph_options[ + "sublinear_mem_config.genetic_pool_size" + ] = sublinear_memory_config.genetic_pool_size + graph_options[ + "sublinear_mem_config.thresh_nr_try" + ] = sublinear_memory_config.thresh_nr_try + graph_options[ + 
"sublinear_mem_config.num_worker" + ] = sublinear_memory_config.num_worker + if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")): + graph_options["var_sanity_check_first_run"] = False + + def apply_options(options): + for k, v in graph_options.items(): + words = k.split(".") + suboptions = options + for word in words[:-1]: + suboptions = getattr(suboptions, word) + setattr(suboptions, words[-1], v) + + self._trace = Trace() + self._trace.symbolic = symbolic or record_only + self._trace.capture_as_const = capture_as_const or record_only + self._trace.no_exec = record_only + self._trace.options_visitor = apply_options + self._trace.profile = profiling + self._trace.array_comparator = array_comparator + self._trace.record_input_shapes = _input_node_use_static_shape() - self._pc += 1 - outputs = [] - for h in ohandles: - info = self._tinfo[h] - # generate output tensor and create compied info - y = RawTensor(info.varnode) - y._compiled_info = CompiledTensorProxy(h) - y._mixin_handle = h - outputs += [y] - self._active_tensors.add(TensorWeakRef(y)) - self._output_handles.update(ohandles) + def __call__(self, *args, **kwargs): + global active_trace + symbolic_shape = None + outputs = None + try: + active_trace = self + self._trace.enter() + if self._capture_as_const: + self._process_inputs(*args, **kwargs) + symbolic_shape = set_symbolic_shape(self._symbolic_shape) + outputs = self.__wrapped__(*args, **kwargs) + finally: + handling_exc = sys.exc_info() != (None,) * 3 + active_trace = None + if symbolic_shape is not None: + symbolic_shape = set_symbolic_shape(symbolic_shape) + assert symbolic_shape == self._symbolic_shape + if self._capture_as_const and (outputs is not None): + self._process_outputs(outputs) + try: + # may raise TraceError + self._trace.exit() + except TraceError: + if not handling_exc: + raise return outputs - def _apply_const(self, value, dtype, device): - assert not self._untraced - # check against trace - if self._pc >= len(self._seq): - raise TraceMismatchError("trace should end here, but more op observed") - record = self._seq[self._pc] - op_, ihandles, ohandles = record - # Const op is represented by a str - assert isinstance(op_, str) and op_ == "Const" - - expected = self._tinfo[ohandles[0]].get_numpy() - shape = value.shape - if shape != expected.shape or dtype != expected.dtype: - eq = False - elif shape == (): - eq = expected.item() == value.item() - elif shape == (1,): - eq = expected[0] == value[0] - else: - eq = np.all(value == expected) - if not eq: - raise TraceMismatchError( - "const tensor violated: got a different tensor this time" - ) + def _process_inputs(self, *args, **kwargs): + for i, arg in enumerate(args): + name_tensor("arg_{}".format(i), arg) - self._pc += 1 - (h,) = ohandles - outputs = [self._tinfo[h].bound_data] - return outputs + # TODO: mark kwargs in order + for k, kwarg in kwargs.items(): + if isinstance(kwarg, RawTensor): + name_tensor("kwarg_{}".format(k), kwarg) - # run in first step, record information for trace - def _record_op(self, op, inputs, outputs): - if skip_tracing: - for x in inputs: - h = getattr(x, "_mixin_handle", -1) - if h >= 0: - self._tinfo[h].data = True - return - - ihandles = [] - for x in inputs: - h = getattr(x, "_mixin_handle", -1) - if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): - h, info = self._new_handle() - name = AutoNaming.gen_name(x) - info.name = name - info.external = True - info.device = x.device - info.dtype = x.dtype - info.shape = x.shape - if self._capture_as_const: - 
info.bound_data = RawTensor( - x.numpy(), x.dtype, x.device, False, name - ) + if self._arg_bindings is None: + self._arg_bindings = [ + ("arg_{}".format(i), arg._tuple_shape) for i, arg in enumerate(args) + ] - ihandles.append(h) - - ohandles = [] - for x in outputs: - h, info = self._new_handle() - ohandles.append(h) - info.external = False - x._mixin_handle = h - x._recording = True - x._trace_mixin_info = info - self._active_tensors.add(TensorWeakRef(x)) - if self._symbolic: - self._lazy_eval_tensors.add(TensorWeakRef(x)) - - self._seq.append((op, tuple(ihandles), tuple(ohandles))) - - def _record_const(self, outputs): - if skip_tracing: - (x,) = outputs - h = getattr(x, "_mixin_handle", -1) - if h >= 0: - self._tinfo[h].data_read = True - return - - (x,) = outputs - h, info = self._new_handle() - ohandles = [h] - info.external = True - info.device = x.device - info.dtype = x.dtype - info.shape = x.shape - info.bound_data = x - info.bound_data_numpy = None - info.is_const = True - x._mixin_handle = h - x._recording = True - x._trace_mixin_info = info - if self._symbolic: - self._lazy_eval_tensors.add(TensorWeakRef(x)) - self._seq.append(("Const", tuple(), tuple(ohandles))) - - def _set_active(self, active: bool): - global active_trace - if active: - if active_trace: - raise NotImplementedError("sorry, not implemented: nested trace") - active_trace = self - else: - assert active_trace is self - active_trace = None + if self._kwarg_bindings is None: + self._kwarg_bindings = { + "kwarg_{}".format(k): (k, kwarg._tuple_shape) + for k, kwarg in kwargs.items() + if isinstance(kwarg, RawTensor) + } - def _init_trace(self, symbolic: bool): - if symbolic: - self._lazy_eval_graph = G.Graph() - self._apply_graph_options(self._lazy_eval_graph) - self._lazy_eval_links = () - - def _take_escaped_tensors(self): - escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) - self._active_tensors.clear() - return escaped_tensors - - def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): - lazy_eval_tensors = [x() for x in lazy_eval_tensors] - lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None] - readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors] - self._apply_graph_options(lazy_eval_graph) - lazy_eval_graph.options.graph_opt_level = self._graph_opt_level - lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) - lazy_eval_graph.compile(*lazy_eval_links, *readers) - self._execute_graph(lazy_eval_graph) - lazy_eval_graph.wait() - for r, x in zip(readers, lazy_eval_tensors): - # get values from lazy_eval_graph and assign to lazy_eval tensor - x._handle = RawTensor(r.op.get_value())._handle - x._reset_varnode() - - @contextlib.contextmanager - def _setup(self): - interrupted = False - - def do_enter(): - set_tracing() - self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape) - self._set_active(True) - if self._untraced: - self._init_trace(self._symbolic) - else: - if self._graph is None: - self._compile() - self._execute_graph(self._graph) - - def do_finalize(): - escaped_tensors = self._take_escaped_tensors() - if self._untraced: - if self._record_only: - self._lazy_eval_graph = None - self._lazy_eval_tensors = None - self._lazy_eval_links = None - else: - for x in escaped_tensors: - if x(): - info = self._tinfo[x()._mixin_handle] - info.data_read = True - x()._mixin_handle = -1 - x()._recording = False - if self._inputs_to_restore: - for x in self._inputs_to_restore: - x._mixin_handle = -1 - 
x._recording = False - if self._symbolic and ( - self._lazy_eval_tensors or self._lazy_eval_links - ): - # eval lazy eval tensors - self._lazy_eval( - self._lazy_eval_graph, - self._lazy_eval_tensors, - self._lazy_eval_links, - ) - self._lazy_eval_graph = None - self._lazy_eval_tensors = None - self._lazy_eval_links = None - self._untraced = False - else: - # compiled_tensor leaks - if self._pc == len(self._seq): - for x in escaped_tensors: - try: - x().__init__(RawTensor(x()._dev_tensor())) - except RuntimeError: - # TraceMismatchError thrown in do_exit - pass - self._graph.wait() - self._reset_exec_env() - - # reset status - self._pc = 0 - self._tensor_remaps = None - self._set_active(False) - set_symbolic_shape(self._save_symbolic_shape) - unset_tracing() - - def do_exit(): - unset_tracing() - if not self._untraced and self._pc != len(self._seq): - raise TraceMismatchError("premature end") - if not self._symbolic or not self._untraced: - # reset output tensors - for x in self._active_tensors.copy(): - strong_x = x() - if strong_x is not None: - strong_x._dev_tensor() - strong_x._reset_varnode() - strong_x._mixin_handle = -1 - strong_x._recording = False - strong_x._trace_mixin_info = None - - try: - do_enter() - yield - do_exit() - except: - interrupted = True - raise - finally: - do_finalize() - if interrupted: - self._reset() + def _process_outputs(self, outputs): + if isinstance(outputs, RawTensor): + outputs = [outputs] + if isinstance(outputs, collections.abc.Mapping): + output_names, outputs = zip(*sorted(outputs.items())) + else: + # output_names = ["output_{}".format(i) for i in range(len(outputs))] + output_names = None + self._output_names = output_names + for i, output in enumerate(outputs): + name_tensor("output_{}".format(i), output) + if self._output_bindings is None: + self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))] def _begin_excluded_region(self): - if self._capture_as_const: - raise RuntimeError( - "exclude_from_trace cannot be used with capture_as_const" - ) - if self._untraced: - # conditionally reading a compiled tensor in excluded region - # is permitted, so we have to assume every tensor might be read - for x in self._active_tensors: - strong_x = x() - if strong_x: - info = self._tinfo[strong_x._mixin_handle] - info.exported = True - info.data_read = True - else: - for x in self._active_tensors: - strong_x = x() - if strong_x: - strong_x._dev_tensor() - - def _apply_graph_options(self, graph): - - graph.options.no_force_inplace = True - graph.options.seq_opt.enable_seq_comp_node_opt = False - graph.options.graph_opt_level = self._graph_opt_level - if self._dtr_config is not None: - graph.options.enable_dtr_memory_opt = True - graph.options.dtr_config.eviction_threshold = ( - self._dtr_config.eviction_threshold - ) - graph.options.dtr_config.evictee_minimum_size = ( - self._dtr_config.evictee_minimum_size - ) - graph.options.dtr_config.recomp_memory_factor = ( - self._dtr_config.recomp_memory_factor - ) - graph.options.dtr_config.recomp_time_factor = ( - self._dtr_config.recomp_time_factor - ) - # graph optimization - if self._graph_opt_config is not None: - mapping = {None: 0, False: 1, True: 2} - jit_config = graph.options.graph_opt.jit_config - jit_config.fuse_dimshuffle = mapping[ - self._graph_opt_config.jit_fuse_dimshuffle - ] - jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce] - # sublinear - if self._sublinear_memory_config is not None: - graph.options.enable_sublinear_memory_opt = True - sublinear_config 
= graph.options.sublinear_mem_config - sublinear_config.lb_memory_mb = self._sublinear_memory_config.lb_memory_mb - sublinear_config.genetic_nr_iter = ( - self._sublinear_memory_config.genetic_nr_iter - ) - sublinear_config.genetic_pool_size = ( - self._sublinear_memory_config.genetic_pool_size - ) - sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try - sublinear_config.num_worker = self._sublinear_memory_config.num_worker - # profile - if self._profiling: - self._profiler = GraphProfiler(graph) - self._profiler2 = None - if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")): - graph.options.var_sanity_check_first_run = False - - def _execute_graph(self, graph: G.Graph, *args): - if is_profiling() and (self._profiler2 is None): - self._profiler2 = GraphProfiler2(graph) - elif not is_profiling() and (self._profiler2 is not None): - self._profiler2 = None - graph.execute(*args) - - def _compile(self): - graph = self._graph = G.Graph() - graph.options.async_exec_level = 0b100 - self._apply_graph_options(graph) - need_reset_nodes = self._need_reset_nodes = [] - # links enforce ordering of I/O nodes - in_out_links = () - io_links = () - readers = [] - - if self._capture_as_const: - for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): - info = self._tinfo[h] - opnode = info.data_setter = G.InputNode( - device=info.device, - dtype=info.dtype, - shape=info.shape or (1,), - graph=graph, - use_static_shape=_input_node_use_static_shape(), - ) - need_reset_nodes.append(opnode) - info.varnode = opnode.outputs[0] - in_out_links += opnode.outputs[1:] - - for op, ihandles, ohandles in self._seq: - if isinstance(op, str) and op == "Const": - assert len(ihandles) == 0 - (h,) = ohandles - info = self._tinfo[h] - if not hasattr(info, "varnode"): - assert info.external - assert info.bound_data - info.varnode = graph.make_const( - info.get_numpy(), info.bound_data.dtype, info.bound_data.device, - ) - continue - - require_links = type(op) in _io_op_types - ivars = [] - for i, h in enumerate(ihandles): - info = self._tinfo[h] - if not hasattr(info, "varnode"): - assert info.external - if info.bound_data: - if getattr(info, "is_const", False): - info.varnode = graph.make_const( - info.get_numpy(), - info.bound_data.dtype, - info.bound_data.device, - ) - else: - info.varnode = graph.make_const( - info.bound_data._dev_tensor() - # info.bound_data.numpy() - ) - else: - opnode = info.data_setter = G.InputNode( - *in_out_links, - device=info.device, - dtype=info.dtype, - shape=info.shape or (1,), - graph=graph, - use_static_shape=_input_node_use_static_shape(), - ) - need_reset_nodes.append(opnode) - info.varnode, *in_out_links = opnode.outputs - if require_links and i == 0 and len(io_links) > 0: - opnode = G.VirtualDepNode( - [info.varnode, *io_links], str(io_links[0].device) - ) - info.varnode = opnode.outputs[0] - io_links = (info.varnode,) - - ivars.append(info.varnode) - - ovars = G.apply_normal_varnode(op, *ivars) - - if require_links and len(ovars) > 0: - io_links = (ovars[0],) - assert len(ovars) == len(ohandles) - for h, v in zip(ohandles, ovars): - info = self._tinfo[h] - info.varnode = v - - def add_reader(opnode): - nonlocal in_out_links - need_reset_nodes.append(opnode) - readers.append(opnode.outputs[0]) - in_out_links = opnode.outputs - - if info.data_read: - # Shape can be obtained from data so doesn't need its own - # output node. 
On the other hand, value is read separately - # to leverage eager h2d copy - info.shape_read = False - opnode = info.data_reader = G.OutputNode(v, *in_out_links) - add_reader(opnode) - if info.value_read: - opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links) - add_reader(opnode) - if info.shape_read: - opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) - add_reader(opnode) - - graph.options.graph_opt_level = self._graph_opt_level - graph._set_priority_to_id([*readers, *in_out_links, *io_links]) - graph.compile(*readers, *in_out_links, *io_links) - - def _reset_exec_env(self): - for opnode in self._need_reset_nodes: - opnode.reset() + self._trace.begin_excluded_region() - def __call__(self, *args, **kwargs): - with self._setup(): - if self._capture_as_const: - self._process_inputs(*args, **kwargs) - outputs = self.__wrapped__(*args, **kwargs) - if self._capture_as_const: - self._process_outputs(outputs) - return outputs + def _end_excluded_region(self): + self._trace.end_excluded_region() def _make_feed( self, @@ -1115,13 +668,11 @@ class trace: raise ValueError( "you must specify capture_as_const=True at __init__ to use dump" ) - if self._untraced and len(self._seq) == 0: - raise RuntimeError("should do record first before dump") if self._output_names and output_names: raise TypeError( "cannot specify output_names when output is already in dict format" ) - if output_names and not isinstance(output_names, collections.abc.Sequence): + if output_names and isinstance(output_names, str): output_names = (output_names,) if output_names and len(output_names) != len(self._output_bindings): raise ValueError( @@ -1129,11 +680,12 @@ class trace: len(self._output_bindings) ) ) - without_arg_names = arg_names is None - if without_arg_names: + prefer_input_names = arg_names is not None + if arg_names is None: arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] - if arg_names and not isinstance(arg_names, collections.abc.Sequence): + if isinstance(arg_names, str): arg_names = (arg_names,) + arg_names = [arg_name if arg_name is not None else "" for arg_name in arg_names] if arg_names and len(arg_names) != len(self._arg_bindings): raise ValueError( "wrong number of arg_names, should be {} values".format( @@ -1142,89 +694,31 @@ class trace: ) output_names = output_names or self._output_names - def dumped_device(info): - device_name = info.device.logical_name - if device_name[:3] in ("cpu", "gpu", "xpu"): - return as_device("xpux") - return info.device - - h2v = {} - graph = G.Graph() + if output_names is None: + output_names = [""] * len(self._output_bindings) + # output_names = ["output_{}".format(i) for i in range(len(self._output_bindings))] - # apply graph_opt_level in dump - if self._graph_opt_level is not None: - graph.options.graph_opt_level = self._graph_opt_level - for i, h in enumerate(self._arg_bindings): - info = self._tinfo[h] - h2v[h] = graph.make_h2d( - dtype=info.dtype, - device=dumped_device(info), - shape=info.shape or (1,), - name=info.name if without_arg_names and info.name else arg_names[i], - ) - for k, h in self._kwarg_bindings.items(): - info = self._tinfo[h] - h2v[h] = graph.make_h2d( - dtype=info.dtype, - device=dumped_device(info), - shape=info.shape or (1,), - name=k, - ) + input_bindings = [] - for op, ihandles, ohandles in self._seq: - if isinstance(op, str) and op == "Const": - assert len(ihandles) == 0 - (h,) = ohandles - info = self._tinfo[h] - if h not in h2v: - assert info.external - assert info.bound_data - h2v[h] = 
graph.make_const( - info.get_numpy(), - dtype=info.dtype, - device=dumped_device(info), - name=info.name, - ) - continue - ivars = [] - for h in ihandles: - info = self._tinfo[h] - if h not in h2v: - assert info.external - assert info.bound_data - h2v[h] = graph.make_const( - info.get_numpy(), - dtype=info.dtype, - device=dumped_device(info), - name=info.name, - ) - ivars.append(h2v[h]) - if isinstance(op, BatchNorm): - assert ( - op.fwd_mode == BatchNorm.FwdMode.INFERENCE - ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" - ovars = G.apply_normal_varnode(op, *ivars) + def normalize_shape(shape): + return (1,) if shape == () else shape - AutoNaming.record_opnode(ovars[0].op) + for arg_name, (arg_id, arg_shape) in zip(arg_names, self._arg_bindings): + input_bindings.append((arg_id, arg_name, normalize_shape(arg_shape))) - assert len(ovars) == len(ohandles) - h2v.update(zip(ohandles, ovars)) + for kwarg_id, (kwarg_name, kwarg_shape) in self._kwarg_bindings.items(): + input_bindings.append((kwarg_id, kwarg_name, normalize_shape(kwarg_shape))) - for i in ohandles: - name = AutoNaming.get_var_name(i) - if name is not None: - h2v[i].name = name - - AutoNaming.remove_duplicate_names() + graph = G.Graph() - dest_vars = [] - for i, h in enumerate(self._output_bindings): - v = h2v[h] - if output_names: - v.name = output_names[i] - dest_vars.append(v) + dest_vars = self._trace.dump( + graph, + input_bindings, + [*zip(self._output_bindings, output_names)], + prefer_input_names, + ) - dest_vars = [i._node for i in dest_vars] + # dest_vars = [i._node for i in dest_vars] if input_data is not None: feeds = self._make_feed( @@ -1260,7 +754,7 @@ class trace: file = open(file, permission) if keep_opr_priority: - graph._set_priority_to_id(dest_vars) + _set_priority_to_id(dest_vars) if input_data is not None: file.write(b"mgbtest0") @@ -1307,300 +801,5 @@ class trace: return dump_info - def _process_inputs(self, *args, **kwargs): - if self._untraced: - self._inputs_to_restore = [] - - def record_input(x): - if x is None: - return - h, info = self._new_handle() - info.external = False - info.name = x.c_name - info.device = x.device - info.dtype = x.dtype - info.shape = x.numpy().shape - x._mixin_handle = h - x._recording = True - x._trace_mixin_info = info - self._inputs_to_restore.append(x) - return h - - self._arg_bindings = [] - for i, x in enumerate(args): - if not isinstance(x, RawTensor): - raise TypeError( - "positional arguments should all be tensor " - "but args[%d] cannot be recognized as one" % i - ) - self._arg_bindings.append(record_input(x)) - - self._kwarg_bindings = {} - for k, x in kwargs.items(): - if isinstance(x, RawTensor): - self._kwarg_bindings[k] = record_input(x) - else: - if len(args) != len(self._arg_bindings): - raise TraceMismatchError("positional argument length mismatch") - - self._tensor_remaps = {} - - for i, (h, x) in enumerate(zip(self._arg_bindings, args)): - if not isinstance(x, RawTensor): - raise TypeError( - "positional arguments should all be tensor " - "but args[%d] cannot be recognized as one" % i - ) - info = self._tinfo[h] - if x.dtype != info.dtype: - raise TypeError("args[%d].dtype different from last time" % i) - if x.device != info.device: - raise TypeError("args[%d].device different from last time" % i) - info.data_setter.set_value(x._dev_tensor()) - self._tensor_remaps[x._handle] = CompiledTensorProxy(h) - - kwargs_tensors = {} - for k, x in kwargs.items(): - if isinstance(x, RawTensor): - kwargs_tensors[k] = x - if 
set(kwargs_tensors) != set(self._kwarg_bindings): - too_many = set(kwargs_tensors) - set(self._kwarg_bindings) - too_few = set(self._kwarg_bindings) - set(kwargs_tensors) - if too_many: - raise TraceMismatchError( - "keyword arguments found to be tensor this time " - "but were non-tensor previously: %s" % " ".join(too_many) - ) - if too_few: - raise TraceMismatchError( - "keyword arguments found to be non-tensor this time " - "but were tensor previously: %s" % " ".join(too_few) - ) - for k, h in self._kwarg_bindings.items(): - x = kwargs_tensors[k] - info = self._tinfo[h] - if x.dtype != info.dtype: - raise TypeError("kwargs[%s].dtype different from last time" % k) - if x.device != info.device: - raise TypeError("kwargs[%s].device different from last time" % k) - info.data_setter.set_value(x._dev_tensor()) - self._tensor_remaps[x._handle] = CompiledTensorProxy(h) - - def _process_outputs(self, outputs): - output_names = None - if isinstance(outputs, collections.abc.Mapping): - output_names, outputs = zip(*sorted(outputs.items())) - elif not isinstance(outputs, collections.abc.Sequence): - outputs = (outputs,) - - if not self._untraced: - if output_names != self._output_names: - too_many = set(output_names) - set(self._output_names) - too_few = set(self._output_names) - set(output_names) - if too_many: - raise TraceMismatchError( - "output has more keys than last time: %s" % " ".join(too_many) - ) - if too_few: - raise TraceMismatchError( - "output has less keys than last time: %s" % " ".join(too_few) - ) - if len(outputs) != len(self._output_bindings): - raise TraceMismatchError("output size differs from last time") - else: - self._output_names = output_names - self._output_bindings = [] - - for i, x in enumerate(outputs): - if not isinstance(x, RawTensor): - raise TypeError("every item of return value should be tensor") - if self._untraced: - h = x._mixin_handle - if h < 0: - raise RuntimeError("output is not computed from inputs") - self._output_bindings.append(h) - else: - h = x._mixin_handle - if h not in self._output_handles: - raise RuntimeError("output is not computed from inputs") - if h != self._output_bindings[i]: - raise TraceMismatchError( - "retval[%s] is a different tensor than last time" - % (output_names and output_names[i] or i) - ) - def get_profile(self): - r"""Get profiling result for compiled trace. - - Return: - a json compatible object. 
- """ - if not self._profiler: - raise RuntimeError("trace is not set with profiling=True") - return json.loads(self._profiler.get()) - - -class CompiledTensorProxy: - r"""Duck-typed RawTensor""" - - def __init__(self, handle): - self.__handle = handle - self._isscalar = False - self.__info = active_trace._tinfo[handle] - self.__shape = None - self.__data = None - self.__value = None - - @property - def dtype(self): - return self.__info.varnode.dtype - - @property - def device(self): - return self.__info.varnode.device - - @property - def shape(self): - if self._isscalar: - return () - if self.__shape is None: - if self.__info.shape_read: - self.__shape = self.__info.shape_reader.get_value().shape - elif self.__info.data_read: - self.__shape = self._dev_tensor().shape - else: - # c++ will throw TraceReadError - return None - return self.__shape - - def numpy(self): - if self.__value is None: - if self.__info.value_read: - self.__value = self.__info.value_reader.get_value() - elif self.__info.data_read: - self.__value = self._dev_tensor().numpy() - else: - # c++ will throw TraceReadError - return None - # c++ side will handle scalar case - return self.__value - - def _dev_tensor(self): - if self.__data is None: - if not self.__info.data_read: - # c++ will throw TraceReadError - return None - self.__data = self.__info.data_reader.get_value() - return self.__data - - def __del__(self): - if self.__info.shape_read and self.__shape is not None: - self.__info.shape_reader.drop_value() - if self.__info.value_read and self.__value is not None: - self.__info.value_reader.drop_value() - if self.__info.data_read and self.__data is not None: - self.__info.data_reader.drop_value() - - -def apply_symbolic_mode(op: OpDef, *args: RawTensor): - graph = active_trace._lazy_eval_graph - ivars = [] - for x in args: - var = getattr(x, "_varnode", None) - if var: - ivars.append(var) - else: - data_setter = G.InputNode( - device=x.device, - dtype=x.dtype, - shape=x.numpy().shape or (1,), - graph=graph, - use_static_shape=True, - ) - var = data_setter.outputs[0] - ivars.append(var) - data_setter.set_value(x._dev_tensor()) - - require_links = type(op) in _io_op_types - - if require_links and active_trace._lazy_eval_links: - assert len(ivars) > 0, "op should has at least one input" - opnode = G.VirtualDepNode( - [ivars[0], *active_trace._lazy_eval_links], - str(active_trace._lazy_eval_links[0].device), - ) - ivars[0] = opnode.outputs[0] - active_trace._lazy_eval_links = (ivars[0],) - - ovars = G.apply_normal_varnode(op, *ivars) - outputs = [RawTensor(o) for o in ovars] - - if require_links: - active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),) - - return outputs - - -def apply_const_symbolic_mode(value, dtype, device, name): - graph = active_trace._lazy_eval_graph - # don't need to unset tracing - # because varnode construction will ignore tracing flag - ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name)) - if np.array(value).ndim == 0: - setscalar(ret) - return (ret,) - - -def apply_compiled_mode(op: OpDef, *args: RawTensor): - if skip_tracing: - args = [ - RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x - for x in args - ] - unset_tracing() - ret = apply(op, *args) - set_tracing() - return ret - return active_trace._apply_op(op, args) - - -def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): - if skip_tracing: - unset_tracing() - ret = RawTensor(value, dtype, device, False, name) - set_tracing() - return ret - return 
active_trace._apply_const(value, dtype, device) - - -def apply_with_tracing(op: OpDef, *args: RawTensor): - if active_trace._graph: - # if member _graph exits, then is_compiled - return apply_compiled_mode(op, *args) - if hasattr(op, "scope"): - op.scope = AutoNaming.get_scope() - if active_trace._symbolic: - outputs = apply_symbolic_mode(op, *args) - else: - unset_tracing() - outputs = apply(op, *args) - set_tracing() - - active_trace._record_op(op, args, outputs) - return list(outputs) - - -def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name): - if active_trace._graph: - return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name) - if active_trace._symbolic: - outputs = apply_const_symbolic_mode(value, dtype, device, name) - else: - unset_tracing() - outputs = RawTensor(value, dtype, device, False, name) - if np.array(value).ndim == 0: - setscalar(outputs) - outputs = (outputs,) - set_tracing() - active_trace._record_const(outputs) - return list(outputs) + return json.loads(self._trace.get_profile()) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index e4735938..717908de 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -111,6 +111,7 @@ class Module(metaclass=ABCMeta): # used for profiler and automatic naming self._name = None + self._short_name = None @abstractmethod def forward(self, inputs): @@ -137,7 +138,7 @@ class Module(metaclass=ABCMeta): return HookHandler(self._forward_hooks, hook) def __call__(self, *inputs, **kwargs): - AutoNaming.push_scope(self.name if self.name is not None else self._name) + AutoNaming.push_scope(self.name if self.name is not None else self._short_name) for hook in self._forward_pre_hooks.values(): modified_inputs = hook(self, inputs) if modified_inputs is not None: @@ -641,15 +642,43 @@ class Module(metaclass=ABCMeta): else: if modules is not None and name in modules: modules.remove(name) - for k, v in _expand_structure(name, value): - if not v._name: - v._name = k - elif v._name != k: + + def append_name(prefix, name): + if prefix is None or prefix == "": + return name + return prefix + "." + name + + def set_name(parent, prefix, name, obj): + if isinstance(obj, Tensor): + assert obj.name is not None + if obj.name != "": + name = obj.name + full_name = append_name(prefix, name) + if obj._short_name and obj._short_name != name: logger.warning( "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( - type(v), type(self), k, v._name + obj._short_name, type(parent), name, obj._short_name ) ) + return + if isinstance(obj, Tensor): + obj._prefix = prefix + obj._name = full_name + obj._short_name = name + obj._set_name(obj._name) + return obj._name + elif isinstance(obj, Module): + obj._name = full_name + obj._short_name = name + for k, v in obj._flatten(recursive=False, with_key=True): + set_name(obj, full_name, k, v) + return obj._name + else: + assert False + + for k, v in _expand_structure(name, value): + prefix = self._name if self._name else self.name + set_name(self, prefix, k, v) super().__setattr__(name, value) def __delattr__(self, name: str): diff --git a/imperative/python/megengine/random/rng.py b/imperative/python/megengine/random/rng.py index e95eb5d5..050cf5e9 100644 --- a/imperative/python/megengine/random/rng.py +++ b/imperative/python/megengine/random/rng.py @@ -14,6 +14,7 @@ from numpy.random import MT19937 from .. 
import Tensor from ..core._imperative_rt.core2 import apply +from ..core._imperative_rt.core2 import sync as _sync from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed from ..core._imperative_rt.ops import ( @@ -650,6 +651,10 @@ class RNG: def __del__(self): if self._handle != 0: + # RNG op might execute after handle released due to async dispatch, so + # we need sync before delete a handle to avoid memory leak or + # use-after-free + _sync() _delete_rng_handle(self._handle) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 955efaa8..e06aa9d3 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -12,7 +12,7 @@ import numpy as np from .core._imperative_rt import CompNode from .core._imperative_rt.core2 import Tensor as _Tensor -from .core._imperative_rt.core2 import apply +from .core._imperative_rt.core2 import apply, set_py_tensor_type from .core._trace_option import use_symbolic_shape from .core._wrap import as_device from .core.ops.builtin import Copy, GetVarShape @@ -20,7 +20,6 @@ from .core.tensor.array_method import ArrayMethodMixin from .device import _valid_device, get_default_device from .logger import get_logger from .utils.deprecation import deprecated -from .utils.naming import AutoNaming logger = get_logger(__name__) @@ -40,6 +39,10 @@ class Tensor(_Tensor, ArrayMethodMixin): grad = None dmap_callback = None _qparams = None + _custom_name = "" + _name = None + _short_name = None + _prefix = None def __new__( cls, @@ -81,9 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin): device: str = None, is_const: bool = False, no_cache: bool = False, - name: str = None, + name: str = "", ): - pass + if name is None: + name = "" + self._custom_name = name + self._name = name + self._short_name = name + self._set_name(self._name) + self._prefix = None @property def shape(self) -> Union[tuple, "Tensor"]: @@ -151,12 +160,13 @@ class Tensor(_Tensor, ArrayMethodMixin): @property def name(self): - return self.c_name + return self._custom_name @name.setter def name(self, name): - self.c_name = name - AutoNaming.record_var_name(self._mixin_handle, name) + self._custom_name = name + self._name = self._prefix + "." + name if self._prefix else name + self._set_name(self._name) @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") def set_value(self, value): @@ -224,6 +234,9 @@ class Tensor(_Tensor, ArrayMethodMixin): self._qparams = qparams +set_py_tensor_type(Tensor) + + tensor = Tensor diff --git a/imperative/python/megengine/traced_module/__init__.py b/imperative/python/megengine/traced_module/__init__.py index 6bbdc668..9568ceb5 100644 --- a/imperative/python/megengine/traced_module/__init__.py +++ b/imperative/python/megengine/traced_module/__init__.py @@ -6,7 +6,6 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..core._imperative_rt.core2 import set_cpp_apply_module_trace from . 
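# NOTE: a sketch of the hazard that the _sync() added to RNG.__del__ above
# guards against; RNG and its normal() method are the public
# megengine.random API, the function body is illustrative.

from megengine.random import RNG

def sample():
    rng = RNG(seed=123)
    return rng.normal(size=(16,))
    # rng is collected on return while the normal() kernel may still be
    # queued; without the sync, the handle could be deleted before the
    # kernel consumes it (use-after-free), or never freed at all (leak).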
import compat from ._passes import optimize from .pytree import register_supported_type @@ -14,14 +13,12 @@ from .tm_config import disable_default_checker, enable_expr_checker from .traced_module import ( TracedModule, _register_all_builtin_module, - cpp_apply_module_trace, register_as_builtin, trace_module, wrap, ) _register_all_builtin_module() -set_cpp_apply_module_trace(cpp_apply_module_trace) __all__ = [ "register_as_builtin", diff --git a/imperative/python/megengine/traced_module/checker.py b/imperative/python/megengine/traced_module/checker.py index 31fa0470..31822b59 100644 --- a/imperative/python/megengine/traced_module/checker.py +++ b/imperative/python/megengine/traced_module/checker.py @@ -13,7 +13,6 @@ import numpy as np from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.ops import ROIAlign, ROIPooling from ..core.ops.builtin import Copy -from ..core.tensor.utils import isscalar, setscalar from ..tensor import Tensor from .tm_config import _exclude_from_trace @@ -70,8 +69,6 @@ class TracedModuleChecker: self.current_node2values()[node] = apply( Copy(comp_node=value.device), value )[0] - if isscalar(value): - setscalar(self.current_node2values()[node]) def check_apply_special_cases(self, opdef, num_outputs): indexs = list(range(num_outputs)) diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index c22249fc..cc486b0a 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -20,6 +20,7 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import ( apply, is_tracing_module, + set_module_trace_hook, set_module_tracing, unset_module_tracing, ) @@ -605,8 +606,7 @@ class Apply(Expr): def apply_module_trace_hook(cls, opdef, *inputs): for i in inputs: node = NodeMixin.get(i, None) - if node is None: # capture as constant - NodeMixin.wrap_safe(i, Constant.make(i)) + assert node is not None if isinstance(opdef, FakeQuant): inp_nodes = [NodeMixin.get(inputs[0])] @@ -805,3 +805,12 @@ class Constant(Expr): if isinstance(v, _ModuleState): state[k] = v.to_module() self.__dict__.update(state) + + +def _module_trace_capture(value): + node = Constant.make(value) + NodeMixin.wrap_safe(value, node) + return node + + +set_module_trace_hook(Apply.apply_module_trace_hook) diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 70a020f4..1e81d7e5 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -101,9 +101,7 @@ BUILTIN_TENSOR_WRAP_METHOD = [ "requires_grad", "_reset", "_isscalar", - "_setscalar", "_tuple_shape", - "_unsetscalar", ] diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 670ab7e9..58f7bf52 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -43,7 +43,6 @@ from ..core._imperative_rt.core2 import ( ) from ..core._trace_option import set_symbolic_shape from ..core.ops.builtin import Copy -from ..core.tensor.utils import isscalar, setscalar from ..module import Module from ..module import external as MExternal from ..module.qat import QATModule @@ -1295,12 +1294,9 @@ def _wrapped_function(orig_func): return orig_func(*args, **kwargs) if isinstance(args[1], 
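# NOTE: with the assertion added to Apply.apply_module_trace_hook above,
# the hook no longer wraps constants itself; any input without a Node is
# expected to be captured via _module_trace_capture first. A rough sketch
# of that contract (dispatch() below is a hypothetical stand-in for the
# real C++ caller):

def dispatch(opdef, *inputs):
    for i in inputs:
        if NodeMixin.get(i, None) is None:
            _module_trace_capture(i)   # wrap the bare tensor as a Constant
    return Apply.apply_module_trace_hook(opdef, *inputs)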
RawTensor): node = NodeMixin.get(inputs[1]) - is_scalar = isscalar(inputs[1]) inputs[1] = apply( Copy(comp_node=inputs[1].device), Tensor(inputs[1]) )[0] - if is_scalar: - setscalar(inputs[1]) # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, # which will cause they have same _NodeMixin__node in tracing. NodeMixin.wrap_safe(inputs[1], node) @@ -2468,8 +2464,8 @@ def trace_module( try: net_name = mod._name if mod._name else mod.__class__.__name__ use_sym_shape = set_symbolic_shape(True) - set_module_tracing() set_active_module_tracer(module_tracer(_wrapped_function)) + set_module_tracing() for cls in [Expr, Node]: cls._set_next_id(0) with active_module_tracer().patcher: @@ -2518,9 +2514,9 @@ def trace_module( return traced_mod finally: set_symbolic_shape(use_sym_shape) - set_active_module_tracer(None) unset_module_tracing() for t in mod.tensors(recursive=True): NodeMixin.clear_node(t) for t in inputs: NodeMixin.clear_node(t) + set_active_module_tracer(None) diff --git a/imperative/python/megengine/utils/profiler.py b/imperative/python/megengine/utils/profiler.py index 0af5f174..702a8e37 100644 --- a/imperative/python/megengine/utils/profiler.py +++ b/imperative/python/megengine/utils/profiler.py @@ -137,6 +137,11 @@ class Profiler(ContextDecorator): get_logger().info("process {} generating {}".format(self._pid, format)) self._dump_callback(path, format) get_logger().info("profiling results written to {}".format(path)) + if os.path.getsize(path) > 64 * 1024 * 1024: + get_logger().warning( + "profiling results too large, maybe you are profiling multi iters," + "consider attach profiler in each iter separately" + ) self._dump_callback = None _living_profilers.remove(self) diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index a1a16117..6680b721 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -9,9 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
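 * NOTE: the 64 MiB warning added to Profiler.dump above points users at
 * profiling a single iteration at a time. A Python sketch, assuming only
 * that Profiler is the ContextDecorator from megengine.utils.profiler;
 * the training loop is illustrative:
 *
 *     from megengine.utils.profiler import Profiler
 *
 *     for step, batch in enumerate(loader):
 *         if step == 10:              # profile one representative iter
 *             with Profiler():
 *                 train_step(batch)
 *         else:
 *             train_step(batch)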
*/ -#pragma GCC diagnostic ignored "-Wmissing-field-initializers" - #include "./grad.h" + #include "megbrain/imperative/backward_graph_opt.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/proxy_graph_detail.h" @@ -19,465 +18,19 @@ #include "range/v3/all.hpp" +#include "./transformation.h" + namespace py = pybind11; namespace views = ranges::views; namespace mgb::imperative::python { -using scoped_disable = ApplyContext::scoped_disable; -using Flags = Tensor::Flags; - namespace { - -struct GradSlotWeakPtr { - std::weak_ptr grad_fn; - size_t idx; -}; - -std::shared_ptr make_backward_graph( - ApplyContext& ctx, const apply_result_t& outputs) { - // hash - using OptimizedBackwardGraphCache = OpMethResultCache< - std::shared_ptr, SmallVector>; - thread_local OptimizedBackwardGraphCache cache; - decltype(cache)::key_t cache_key{ctx.op}; - SmallVector& input_descs = cache_key.inputs; - SmallVector& input_requires_grad = std::get<0>(cache_key.extras); - input_descs.resize(ctx.nargs); - input_requires_grad.resize(ctx.nargs); - for (size_t i = 0; i < ctx.nargs; ++i) { - input_descs[i].layout.dtype = ctx.args[i]->dtype(); - input_descs[i].comp_node = ctx.args[i]->comp_node(); - input_requires_grad[i] = python::input_requires_grad(ctx, i); - } - - auto iter = cache.find(cache_key); - if (iter != cache.end()) { - return iter->second; - } - - // slow path - SmallVector output_has_grad(outputs.size(), true); - std::shared_ptr ret; - auto bg = OpDef::make_backward_graph( - *ctx.op, input_descs, input_requires_grad, output_has_grad); - if (!bg.graph.empty()) { - ret = std::make_shared(bg); - } - cache.emplace(cache_key, ret); - return ret; +std::unordered_map, GradKeyWrapper*> grad_key_map; } -struct BackwardGraphWithClosure { - std::shared_ptr backward_graph; - SmallVector> closure; - size_t output_mask_offset; - size_t grad_mask_offset; - - BackwardGraphWithClosure( - std::shared_ptr backward_graph_, - ApplyContext& ctx, const apply_result_t& outputs) - : backward_graph(backward_graph_), - output_mask_offset(ctx.nargs), - grad_mask_offset(ctx.nargs + outputs.size()) { - // save_for_backward[0:nargs]: - // whether input is kept for backward - // - // save_for_backward[nargs:nargs+outputs.size()]: - // whether output is kept for backward - // - // save_for_backward[-outputs.size():]: - // whether gradient of output can propagate to any input - // - // Example: - // perform c = a * b, with a.requires_grad == True and - // b.requires_grad == False, save_for_backward = [0, 1, 0, 1] - auto& save_for_backward = backward_graph->save_for_backward; - mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); - size_t count = std::count_if( - save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); - if (!backward_graph->precomp.empty()) { - auto&& irng = ranges::span(ctx.args, ctx.nargs); - auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); }); - auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); - closure.reserve(precomp.size() + count); - std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure)); - } else { - closure.reserve(count); - } - for (size_t i = 0; i < ctx.nargs; ++i) { - if (save_for_backward[i]) { - closure.push_back(ctx.args[i]->shared_from_this()); - } - } - for (size_t i = 0; i < outputs.size(); ++i) { - if (save_for_backward[ctx.nargs + i]) { - closure.push_back(outputs[i]); - } - } - } - - template - void operator()(BackwardContext&, T&& grads, R&& receiver) { - Tensor* 
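// NOTE: a worked instance of the save_for_backward layout documented in the
// constructor above. For c = a * b with a.requires_grad == true and
// b.requires_grad == false: nargs == 2 and outputs.size() == 1, so the mask
// has 2 + 2 * 1 == 4 entries,
//     [a kept, b kept, c kept, dc propagates] == [0, 1, 0, 1]:
// b is captured because da = dc * b needs it, c itself is not needed, and
// the gradient of c can propagate to input a.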
args[closure.size() + grads.size()]; - size_t nargs = 0; - for (auto&& t : closure) { - args[nargs++] = t.get(); - } - bool null_grad = false; - for (size_t i = 0; i < grads.size(); ++i) { - if (backward_graph->save_for_backward[grad_mask_offset + i]) { - if (grads[i]) { - if (null_grad) { - PyErr_SetString(PyExc_NotImplementedError, "report to devs"); - throw py::error_already_set(); - } - args[nargs++] = grads[i]; - } else { - null_grad = true; - } - } - } - if (null_grad) - return; - - auto igrads = apply(backward_graph->backward, args, nargs); - auto&& it = igrads.begin(); - for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) { - if (p) { - receiver(i, std::move(*it)); - ++it; - } - } - } - - bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } - - bool output_requires_grad(size_t i) { - return backward_graph->save_for_backward[grad_mask_offset + i]; - } - - bool output_captured(size_t i) { - return backward_graph->save_for_backward[output_mask_offset + i]; - } -}; - -struct PythonBackward { - py::object pyfunc; - size_t input_size; - - PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {} - - template - void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { - auto args = py::tuple(grads.size()); - for (size_t i = 0; i < grads.size(); ++i) { - auto&& g = grads[i]; - args[i] = g ? ctx.wrap_tensor(g) : py::none(); - } - auto input_grads = py::reinterpret_steal( - PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); - if (!input_grads) - throw py::error_already_set(); - if (input_grads.is_none()) - return; - if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { - if (input_size != 1) { - throw py::value_error( - "custom grad rule returned wrong number of grads"); - } - if (!ctx.pytype) { - ctx.pytype = Py_TYPE(input_grads.ptr()); - } - receiver(0, tw->m_tensor); - return; - } - if (py::len(input_grads) != input_size) { - throw py::value_error("custom grad rule returned wrong number of grads"); - } - for (auto [i, g] : views::enumerate(input_grads)) { - if (g.is_none()) - continue; - auto* tw = TensorWrapper::try_cast(g.ptr()); - if (!tw) { - throw py::type_error("custom grad rule returned non-tensor"); - } - if (!ctx.pytype) { - ctx.pytype = Py_TYPE(g.ptr()); - } - receiver(i, tw->m_tensor); - } - } - - static constexpr bool input_has_grad(size_t) { return true; } - static constexpr bool output_requires_grad(size_t) { return true; } - static constexpr bool output_captured(size_t) { return true; } -}; - -} // namespace - -struct GradProducerRecord : intrusive_list::Node { - using Base = intrusive_list::Node; - - GradProducerRecord() = default; - GradProducerRecord(GradProducerRecord::head_t& head) - : Base(intrusive_list::after_t{}, head) {} - // GradProducerRecord(GradProducerRecord&&) = default; - // GradProducerRecord& operator=(GradProducerRecord&) = default; - // GradProducerRecord& operator=(GradProducerRecord&&) = default; -}; - -struct GradSlot { - std::shared_ptr grad; - py::object callback; - GradProducerRecord::head_t producer_head; -}; - -struct GradSlotProducerPtr : GradSlotPtr { - GradProducerRecord producer_record; - - GradSlotProducerPtr() = default; - GradSlotProducerPtr(GradInfo& info) - : GradSlotPtr(info), producer_record(info->producer_head) {} -}; - -struct GradFn : std::enable_shared_from_this { - static MemPool pool; - - std::weak_ptr key; - // slots for receiving and accumulating grads - // same length as outputs (of forward op) - SmallVector slots; - // where to send and accumulate grads - 
// same length as inputs (of forward op) - SmallVector dsts; - // encapsules actual function to compute gradient - std::variant< - std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> - backward; - // a flag used during backward - bool in_ref_keeper = false; - - static void deleter(GradFn* ptr) { pool.free(ptr); } - - static std::shared_ptr make() { - return std::shared_ptr(pool.alloc(), &deleter); - } - - void clear() { - key.reset(); - slots.clear(); - dsts.clear(); - backward.emplace(); - } -}; - -GradSlotPtr::operator bool() const { - return bool(grad_fn); -} - -GradSlot* GradSlotPtr::operator->() { - return &grad_fn->slots[idx]; -} - -namespace { - -class GradFnHelper { - std::shared_ptr grad_fn; - - GradFn* get() { - if (!grad_fn) { - grad_fn = std::make_shared(); - } - return grad_fn.get(); - } - - friend apply_result_t imperative::python::apply_grad(ApplyContext&); - -public: - template - auto& emplace(Args&&... args) { - return get()->backward.emplace(std::forward(args)...); - } - - void reset() { grad_fn = nullptr; } -}; - -apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { - // copy inputs first, or trace will make InputNodes for each usage - ApplyContext ctx_dup = ctx; - SmallVector> inputs_copy; - SmallVector inputs_copy_weak; - for (size_t i = 0; i < ctx.nargs; ++i) { - Tensor* input = ctx.args[i]; - inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]); - inputs_copy_weak.push_back(inputs_copy.back().get()); - inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; - if (input->m_flags & Flags::GRAD) { - inputs_copy.back()->m_flags |= Flags::GRAD; - } - } - ctx_dup.args = inputs_copy_weak.data(); - - auto outputs = apply(ctx_dup); - - auto backward_graph = make_backward_graph(ctx_dup, outputs); - if (!backward_graph) { - return outputs; - } - ret_grad_fn.emplace( - std::move(backward_graph), ctx_dup, outputs); - - return outputs; -} - -apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { - auto* op = ctx.op->try_cast_final(); - py::tuple pyin(ctx.nargs); - for (size_t i = 0; i < ctx.nargs; ++i) { - pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); - } - auto grad_rule = py::getattr(op->obj, "_grad_rule"); - auto pyret = py::reinterpret_steal( - PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); - if (!pyret) - throw py::error_already_set(); - auto [outputs, backward] = py::cast>(pyret); - ret_grad_fn.emplace(std::move(backward), ctx.nargs); - if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { - return {tw->m_tensor}; - } - apply_result_t ret; - ret.reserve(py::len(outputs)); - for (auto&& i : outputs) { - auto* tw = TensorWrapper::try_cast(i.ptr()); - mgb_assert(tw); - ret.push_back(tw->m_tensor); - } - return ret; -} - -} // namespace - -apply_result_t apply_grad(ApplyContext& ctx) { - std::unordered_set> grad_keys; - for (size_t i = 0; i < ctx.nargs; ++i) { - auto* tensor = ctx.args[i]; - if (!tensor->m_grad_info_dict.empty()) { - size_t grad_cnt = 0; - for (auto&& grad_info : tensor->m_grad_info_dict) { - auto input_grad_key = grad_info.grad_fn->key.lock(); - if (input_grad_key && input_grad_key->active && - !input_grad_key->is_blocked()) { - grad_keys.insert(input_grad_key); - grad_cnt++; - } - } - if (!grad_cnt) { - tensor->m_flags &= ~Flags::GRAD; - } - } else { - tensor->m_flags &= ~Flags::GRAD; - } - } - - ctx.flags &= ~Flags::GRAD; - - if (grad_keys.empty()) { - return apply(ctx); - } else if (grad_keys.size() > 1 && 
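// NOTE: hitting more than one active GradKey in a single apply is the
// "higher order directive" case rejected just below; per that error
// message, users are expected to opt in from Python via
//     megengine.experimental.enable_higher_order_directive()
// which presumably flips GradKey::allow_higher_order_directive.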
!GradKey::allow_higher_order_directive) { - PyErr_SetString( - PyExc_NotImplementedError, - "second order directive not enabled, please call " - "'megengine.experimental.enable_higher_order_directive'"); - throw pyext17::py_err_set(); - } - - GradFnHelper grad_fn_holder; - auto outputs = [&]() { - auto _ = scoped_disable(Flags::GRAD); - if (ctx.op->same_type()) { - return python_grad_rule(ctx, grad_fn_holder); - } - auto&& registry = grad_rule_registry(); - auto&& it = registry.find(ctx.op->dyn_typeinfo()); - if (it != registry.end()) { - auto&& maker = grad_fn_holder.emplace().maker(ctx); - if (auto ret = it->second(ctx, maker)) { - maker.finalize(); - return *ret; - } - grad_fn_holder.reset(); - } - return backward_graph_grad_rule(ctx, grad_fn_holder); - }(); - - if (!grad_fn_holder.grad_fn) { - return outputs; - } - - for (auto&& grad_key : grad_keys) { - auto grad_fn = std::make_shared(); - grad_fn->backward = grad_fn_holder.grad_fn->backward; - grad_fn->key = grad_key; - grad_fn->slots.resize(outputs.size()); - grad_fn->dsts.reserve(ctx.nargs); - - std::visit( - [&](auto& backward) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - mgb_assert(0); - } else { - for (size_t i = 0; i < ctx.nargs; ++i) { - if (backward.input_has_grad(i) && - input_requires_grad(ctx, i) && - ctx.args[i]->m_grad_info_dict.count(grad_key.get())) { - auto& input_grad_info = - ctx.args[i]->m_grad_info_dict.at( - grad_key.get()); - grad_fn->dsts.emplace_back(input_grad_info); - // register as grad producer - grad_fn->dsts.back().producer_record.insert_after( - input_grad_info->producer_head); - } else { - grad_fn->dsts.emplace_back(); - } - } - for (size_t i = 0; i < outputs.size(); ++i) { - if (backward.output_requires_grad(i)) { - if (backward.output_captured(i)) { - // avoid reference cycle [Tensor <-> GradFn] - static std::shared_ptr op = - std::make_shared(); - outputs[i] = python::apply(op, outputs[i])[0]; - } - // populate grad info of output tensor - auto& grad_info = - outputs[i]->m_grad_info_dict[grad_key.get()]; - grad_info.grad_fn = grad_fn; - grad_info.idx = i; - grad_info.insert_after(grad_key->free_vars_head); - outputs[i]->m_flags |= Flags::GRAD; - } - } - } - }, - grad_fn->backward); - - // record forward history - grad_key->tape.emplace_back(grad_fn); - } - - return outputs; -} - -PyObject* GradKeyWrapper::get_priority() { - return py::cast(m_key->priority).release().ptr(); -} - -void GradKeyWrapper::set_priority(pybind11::handle priority) { - m_key->priority = py::cast(priority); +GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared()) { + grad_key_map[m_key] = this; } void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { @@ -488,157 +41,59 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { if (!tw) { throw py::type_error("argument 1 must be Tensor"); } - auto* tensor = tw->m_tensor.get(); py::object callback; if (args[1] != Py_None) { callback = py::reinterpret_borrow(args[1]); } - m_key->attach(tensor, std::move(callback)); -} - -//! 
GradKey is weakly refered by tensor->m_grad_info.grad_fn->key after attach -void GradKey::attach(Tensor* tensor, pybind11::object callback) { - if (!active) { - throw py::value_error("grad key finalized"); - } - - if (tensor->m_grad_info_dict.count(this)) { - if (tensor->m_grad_info_dict.at(this)->callback) { - throw py::value_error("callback already set on this tensor"); + GenericFunction generic_callback = + [=](Span inputs) -> std::vector { + mgb_assert(inputs.size() == 1); + if (callback) { + callback(TensorWrapper::make(py_tensor_type, inputs[0])); } - } else { - auto& grad_info = tensor->m_grad_info_dict[this]; - grad_info.idx = 0; - auto& grad_fn = grad_info.grad_fn; - grad_fn = std::make_shared(); - grad_fn->key = shared_from_this(); - grad_fn->slots.resize(1); - grad_info.insert_after(free_vars_head); - tensor->m_flags |= Flags::GRAD; - } - tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback); -} - -template -void accum_grad(std::shared_ptr& grad, T&& delta) { - if (!grad) { - grad = std::forward(delta); - return; - } - static std::shared_ptr op = - std::shared_ptr(new Elemwise(Elemwise::Mode::ADD)); - grad = apply(op, grad, std::forward(delta))[0]; + return {}; + }; + tw->m_tensor->reset(imperative::apply( + AttachGrad(m_key), tw->m_tensor->data(), + FunctionValue::make(generic_callback))[0]); } -void GradKey::backward( - std::vector tensors, std::vector grads) { - if (!active) { - throw py::value_error("finalized"); +void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { + std::vector args; + mgb_assert(tensors.size() == grads.size()); + for (auto&& tensor : tensors) { + args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); } - if (tensors.size() != grads.size()) { - throw py::value_error("tensor and grad size mismatch"); + for (auto&& grad : grads) { + args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); } - - // this GradKey is marked inactive here - active = false; - struct CleanupGuard { - GradKey* owner; - size_t priority_backup; - CleanupGuard(GradKey* this_) : owner(this_) { - priority_backup = sm_min_priority; - sm_min_priority = owner->priority + 1; - } - ~CleanupGuard() { - owner->cleanup(); - sm_min_priority = priority_backup; - } - } _cleanup_guard(this); - - if (tape.empty()) - return; - - BackwardContext bctx; - if (!grads.empty()) { - bctx.pytype = Py_TYPE(grads[0]->self().ptr()); - } - - for (size_t i = 0; i < tensors.size(); ++i) { - if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) { - continue; - } - auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this); - grad_info->grad = grads[i]->m_tensor; - } - - std::vector> ref_keeper; - ref_keeper.reserve(tape.size()); - - // back-propagation in reverse order - for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { - auto&& grad_fn = tape[k].lock(); - if (!grad_fn) - continue; - - auto grad_receiver = [&](size_t i, auto&& g) { - auto& dst = grad_fn->dsts[i]; - if (dst) { - accum_grad(dst->grad, std::forward(g)); - } - }; - std::visit( - [&](auto&& backward) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - mgb_assert(0); - } else { - auto&& grads = views::transform( - grad_fn->slots, - [](auto&& slot) { return slot.grad.get(); }); - backward( - bctx, std::forward(grads), - grad_receiver); - } - }, - grad_fn->backward); - - for (auto&& dst : grad_fn->dsts) { - if (!dst.grad_fn) - continue; - if (!dst.grad_fn->in_ref_keeper) { - // after grad_fn is cleared, refcnt of subsequent grad_fn 
- // could drop to 0 - dst.grad_fn->in_ref_keeper = true; - ref_keeper.push_back(dst.grad_fn); - } - if (!dst.producer_record.next && dst->callback && dst->grad) { - // I'm the last grad producer, invoke callback - dst->callback(bctx.wrap_tensor(dst->grad)); - } - } - grad_fn->clear(); - } // finish tape loop + imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); } -void GradKey::cleanup() { - active = false; - tape.clear(); - for (intrusive_list::Iterator it(free_vars_head); it;) { - it->grad_fn.reset(); - (it++)->unlink(); +pybind11::function GradKeyWrapper::get_backward_closure( + GradKeyWrapper* self, py::list tensors) { + std::vector args; + for (auto&& tensor : tensors) { + args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); } -} - -void GradKeyWrapper::backward( - std::vector tensors, std::vector grads) { - m_key->backward(std::move(tensors), std::move(grads)); + auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0] + .as(); + auto py_function = [closure](std::vector tensors) { + std::vector args; + for (auto* tw : tensors) { + args.push_back(tw->m_tensor->data()); + } + (*closure)(args); + }; + return pybind11::cpp_function(py_function); } PyObject* GradKeyWrapper::get_name() { - return py::cast(m_key->name).release().ptr(); + return py::cast(m_key->name()).release().ptr(); } void GradKeyWrapper::set_name(py::handle name) { - m_key->name = py::cast(name); + m_key->name(py::cast(name)); } PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { @@ -651,60 +106,39 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { PyErr_SetString(PyExc_TypeError, "expect Tensor"); return nullptr; } - if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) { + if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0] + .cast()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -int GradKey::sm_min_priority = std::numeric_limits::min(); - -GradKey::~GradKey() { - cleanup(); +void GradKeyWrapper::enter() { + m_transformation = std::make_shared(m_key); + TransformationManager::get_instance().register_at( + m_transformation); } -std::unordered_map& grad_rule_registry() { - static std::unordered_map registry; - return registry; +void GradKeyWrapper::exit() { + TransformationManager::get_instance().unregister( + m_transformation); + m_transformation.reset(); } -void GradInfoCollection::_shrink() { - auto pred = [](GradInfo& info) { - return !(info.grad_fn) || info.grad_fn->key.expired(); - }; - auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred); - m_storage.erase(iter, m_storage.end()); +void GradKeyWrapper::suppress() { + m_transformation->suppress(); } -bool GradInfoCollection::contains(GradKey* key) { - _shrink(); - for (auto&& grad_info : m_storage) { - if (grad_info.grad_fn->key.lock().get() == key) { - return true; - } - } - return false; +void GradKeyWrapper::resume() { + m_transformation->resume(); } -GradInfo& GradInfoCollection::operator[](GradKey* key) { - _shrink(); - for (auto&& grad_info : m_storage) { - if (grad_info.grad_fn->key.lock().get() == key) { - return grad_info; - } - } - m_storage.emplace_back(); - return m_storage.back(); +GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr key) { + return grad_key_map.at(key); } -GradInfo& GradInfoCollection::at(GradKey* key) { - _shrink(); - for (auto&& grad_info : m_storage) { - if (grad_info.grad_fn->key.lock().get() == key) { - return grad_info; - } - } - mgb_assert(false); +GradKeyWrapper::~GradKeyWrapper() { 
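// NOTE: in the refactored wrapper above, enter() installs a
// GradTransformation for m_key and exit() uninstalls it, while
// grad_key_map only backs the static get() lookup; hence the erase in
// this destructor.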
+ grad_key_map.erase(m_key); } } // namespace mgb::imperative::python diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h index 72faef1f..d5175aaf 100644 --- a/imperative/python/src/grad.h +++ b/imperative/python/src/grad.h @@ -12,166 +12,40 @@ #pragma once #include "./tensor.h" + #include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/transformations/grad.h" +#include "megbrain/utils/small_vector.h" -#include #include #include namespace mgb::imperative::python { -apply_result_t apply_grad(ApplyContext& ctx); - -struct GradKey : std::enable_shared_from_this, NonCopyableObj { - std::string name; - bool active = true; - GradInfo::head_t free_vars_head; - std::vector> tape; - int priority = 0; - - ~GradKey(); - - void attach(Tensor* tensor, pybind11::object callback); - void backward(std::vector, std::vector); - void cleanup(); - bool is_blocked() const { return priority < sm_min_priority; } - inline static bool allow_higher_order_directive = false; - -private: - static int sm_min_priority; -}; - -struct GradKeyWrapper { +struct GradKeyWrapper : NonCopyableObj { using wrap_t = pyext17::wrap; static constexpr auto tp_name = pybind11::detail::_("GradKey"); std::shared_ptr m_key; + std::shared_ptr m_transformation; - inline GradKeyWrapper() : m_key(std::make_shared()) {} + GradKeyWrapper(); PyObject* get_name(); void set_name(pybind11::handle name); - PyObject* get_priority(); - void set_priority(pybind11::handle priority); void attach(PyObject* const* args, size_t nargs); - void backward(std::vector, std::vector); + static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list); + static pybind11::function get_backward_closure( + GradKeyWrapper* self, pybind11::list); PyObject* is_attached_to(PyObject* const* args, size_t nargs); + void enter(); + void exit(); + void suppress(); + void resume(); + static GradKeyWrapper* get(std::shared_ptr key); + ~GradKeyWrapper(); }; -struct BackwardContext { - PyTypeObject* pytype = nullptr; - - auto wrap_tensor(std::shared_ptr t) { - if (pytype) { - return TensorWrapper::make(pytype, std::move(t)); - } - return TensorWrapper::make(std::move(t)); - } - - auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); } -}; - -struct CustomBackward { - using BackwardFn = - std::function; - BackwardFn m_backward; - SmallVector m_input_has_grad; - struct OutputAttr { - bool requires_grad = true, captured = true; - }; - SmallVector m_output_attrs; - -public: - template - void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { - size_t nargs = grads.size(); - Tensor* args[nargs]; - for (size_t i = 0; i < nargs; ++i) { - args[i] = grads[i]; - } - auto ret = m_backward(ctx, args, nargs); - for (size_t i = 0; i < ret.size(); ++i) { - if (auto&& t = ret[i]) { - receiver(i, std::move(t)); - } - } - } - - bool input_has_grad(size_t i) { return m_input_has_grad[i]; } - bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } - bool output_captured(size_t i) { return m_output_attrs[i].captured; } - - class Maker { - bool output_size_set = false, input_has_grad_initialized = false; - CustomBackward& target; - ApplyContext& ctx; - - void init_input_has_grad() { - if (!input_has_grad_initialized) { - input_has_grad_initialized = true; - target.m_input_has_grad.resize(ctx.nargs, true); - } - } - - public: - Maker(CustomBackward& target_, ApplyContext& ctx_) - : target(target_), ctx(ctx_) {} - - template - Maker& backward(F&& f) { - mgb_assert(!target.m_backward); - target.m_backward 
= std::forward(f); - return *this; - } - // mandatory - Maker& output_size(size_t sz) { - mgb_assert(!output_size_set); - output_size_set = true; - target.m_output_attrs.resize(sz); - return *this; - } - // optional, defaults to all true - Maker& input_has_grad(size_t i, bool v) { - init_input_has_grad(); - target.m_input_has_grad.at(i) = v; - return *this; - } - // optional, defaults to all true - Maker& output_requires_grad(size_t i, bool v) { - target.m_output_attrs.at(i).requires_grad = v; - return *this; - } - // optional, defaults to all true - Maker& output_captured(size_t i, bool v) { - target.m_output_attrs.at(i).captured = v; - return *this; - } - - void finalize() { - mgb_assert(output_size_set); - init_input_has_grad(); - } - }; - - Maker maker(ApplyContext& ctx) { return {*this, ctx}; } -}; - -using GradRuleFn = std::function( - ApplyContext&, CustomBackward::Maker&)>; - -std::unordered_map& grad_rule_registry(); - -inline bool input_requires_grad(const ApplyContext& ctx, size_t i) { - return !ctx.args[i]->m_grad_info_dict.empty(); -} - -struct GradRuleFallback : std::exception {}; - -template -bool register_grad_rule(Typeinfo* typeinfo, T&& rule) { - return grad_rule_registry().emplace(typeinfo, std::forward(rule)).second; -} - } // namespace mgb::imperative::python namespace pybind11::detail { diff --git a/imperative/python/src/grad_info.h b/imperative/python/src/grad_info.h deleted file mode 100644 index 1e765ef5..00000000 --- a/imperative/python/src/grad_info.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * \file imperative/python/src/grad_info.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- */ - -#include - -#include "./intrusive_list.h" - -namespace mgb::imperative::python { - -struct GradKey; -struct GradFn; -struct GradSlot; - -struct GradSlotPtr { - std::shared_ptr grad_fn; - size_t idx; - - operator bool() const; - GradSlot* operator->(); -}; - -struct GradInfo : GradSlotPtr, - intrusive_list::Node { - GradInfo() = default; - GradInfo(GradInfo&) = default; - GradInfo(GradInfo&&) = default; - GradInfo& operator=(GradInfo&) = default; - GradInfo& operator=(GradInfo&&) = default; - GradInfo(const GradInfo& rhs) : GradInfo(const_cast(rhs)) {} - GradInfo& operator=(const GradInfo& rhs) { - return *this = const_cast(rhs); - } -}; - -} // namespace mgb::imperative::python diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index c8d21596..5269113d 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -11,261 +11,334 @@ #include "./grad.h" #include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/transformations/grad.h" namespace mgb::imperative::python { + +class CustomGradMaker { + bool output_size_set = false, input_has_grad_initialized = false; + CustomBackward& target; + size_t nr_inputs; + void init_input_has_grad() { + if (!input_has_grad_initialized) { + input_has_grad_initialized = true; + target.m_input_has_grad.resize(nr_inputs, true); + } + } + +public: + CustomGradMaker(CustomBackward& target, size_t nr_inputs) + : target(target), nr_inputs(nr_inputs) {} + + CustomGradMaker& backward(CustomBackward::BackwardFn f) { + mgb_assert(!target.m_backward); + target.m_backward = f; + return *this; + } + // mandatory + CustomGradMaker& output_size(size_t sz) { + mgb_assert(!output_size_set); + output_size_set = true; + target.m_output_attrs.resize(sz); + return *this; + } + // optional, defaults to all true + CustomGradMaker& input_has_grad(size_t i, bool v) { + init_input_has_grad(); + target.m_input_has_grad.at(i) = v; + return *this; + } + // optional, defaults to all true + CustomGradMaker& output_requires_grad(size_t i, bool v) { + target.m_output_attrs.at(i).requires_grad = v; + return *this; + } + // optional, defaults to all true + CustomGradMaker& output_captured(size_t i, bool v) { + target.m_output_attrs.at(i).captured = v; + return *this; + } + void finalize() { + mgb_assert(output_size_set); + init_input_has_grad(); + } +}; + namespace { -std::shared_ptr get_shape(Tensor* x) { +ValueRef get_shape(ValueRef x) { static auto op = GetVarShape::make(); - return python::apply(op, x)[0]; + return imperative::apply(*op, x)[0]; } -std::shared_ptr reduce_to(Tensor* x, Tensor* s) { +ValueRef reduce_to(ValueRef x, ValueRef s) { static auto op = Reduce::make(); - return python::apply(op, x, s)[0]; + return imperative::apply(*op, x, s)[0]; } -std::shared_ptr reshape_to(Tensor* x, Tensor* s) { +ValueRef reshape_to(ValueRef x, ValueRef s) { static auto op = Reshape::make(); - return python::apply(op, x, s)[0]; + return imperative::apply(*op, x, s)[0]; } -std::shared_ptr broadcast_to(Tensor* x, Tensor* s) { +ValueRef broadcast_to(ValueRef x, ValueRef s) { static auto op = Broadcast::make(); - return python::apply(op, x, s)[0]; + return imperative::apply(*op, x, s)[0]; } -std::shared_ptr make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) { - HostTensorND scalar{cn, {{1}, dtype}}; - std::memset(scalar.raw_ptr(), 0, dtype.size()); - interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false); - auto&& t = std::make_shared(handle); - auto res = 
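// NOTE: make_empty_tensor (the deleted version here and its ValueRef
// replacement just below) builds the zero base used by the
// SetSubtensor-style grad rules: allocate a single zero element of the
// requested dtype on the target device, then broadcast_to() the saved
// input shape, so no full-size zero buffer is ever materialized.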
broadcast_to(t.get(), shape); +ValueRef make_empty_tensor( + CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) { + HostTensorStorage storage(*device); + storage.ensure_size(dtype->size()); + std::memset(storage.ptr(), 0, dtype->size()); + auto t = imperative::apply( + CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()), + HostStorage::make(storage))[0]; + auto res = broadcast_to(t, shape); return res; } -std::optional elemwise_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - auto& op = ctx.op->cast_final_safe(); - if (op.mode == Elemwise::Mode::ADD) { - mgb_assert(ctx.nargs == 2); - std::array, 2> input_shapes; +std::optional> elemwise_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto& elemwise = op.cast_final_safe(); + if (elemwise.mode != Elemwise::Mode::ADD) { + return {}; + } + mgb_assert(inputs.size() == 2); + std::array input_shapes; + for (size_t i = 0; i < 2; ++i) { + if (inputs_require_grad[i]) { + input_shapes[i] = get_shape(inputs[i]); + } + } + auto maker = CustomGradMaker(backward, inputs.size()); + maker.output_size(1).output_captured(0, false); + maker.backward([shapes = std::move(input_shapes)](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(2); + if (!grad) { + return ret; + } for (size_t i = 0; i < 2; ++i) { - if (input_requires_grad(ctx, i)) { - input_shapes[i] = get_shape(ctx.args[i]); + if (shapes[i]) { + ret[i] = reduce_to(grad, shapes[i]); } } - maker.output_size(1).output_captured(0, false); - maker.backward([shapes = std::move(input_shapes)]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(2); - if (!grad) { - return ret; - } - for (size_t i = 0; i < 2; ++i) { - if (shapes[i]) { - ret[i] = reduce_to(grad, shapes[i].get()); - } - } - return ret; - }); - return apply(ctx); - } - return {}; + return ret; + }); + maker.finalize(); + return imperative::apply(ApplyOp(op), inputs); } -std::optional reshape_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - mgb_assert(ctx.nargs == 2); - std::array, 2> input_shapes; +std::optional> reshape_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + mgb_assert(inputs.size() == 2); + std::array input_shapes; for (size_t i = 0; i < 2; ++i) { - if (input_requires_grad(ctx, i)) { - input_shapes[i] = get_shape(ctx.args[i]); + if (inputs_require_grad[i]) { + input_shapes[i] = get_shape(inputs[i]); } } + auto maker = CustomGradMaker(backward, inputs.size()); maker.output_size(1).output_captured(0, false); - maker.backward([shapes = std::move(input_shapes)]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(2); + maker.backward([shapes = std::move(input_shapes)](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(2); if (!grad) { return ret; } for (size_t i = 0; i < 2; ++i) { if (shapes[i]) { - ret[i] = reshape_to(grad, shapes[i].get()); + ret[i] = reshape_to(grad, shapes[i]); } } return ret; }); - return apply(ctx); + maker.finalize(); + return imperative::apply(ApplyOp(op), inputs); } -std::optional subtensor_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - auto&& op = ctx.op->cast_final_safe(); - auto&& grad_op = SetSubtensor::make(op.items); - SmallVector> inputs; - if (input_requires_grad(ctx, 0)) { 
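// NOTE: the reduce_to() used by the ADD rule above exists because the
// forward may broadcast, e.g. (4, 3) + (1, 3) -> (4, 3): the incoming grad
// then has shape (4, 3) and the gradient w.r.t. the (1, 3) input must be
// summed back to (1, 3). reduce_to(grad, saved_shape) performs exactly
// that reduction, which is why each input's shape is captured before the
// forward runs.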
- inputs.push_back(get_shape(ctx.args[0])); - for (size_t i = 1; i < ctx.nargs; ++i) { - inputs.push_back(ctx.args[i]->copy()); +std::optional> subtensor_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& subtensor = op.cast_final_safe(); + auto&& grad_op = SetSubtensor::make(subtensor.items); + SmallVector inputs2; + if (inputs_require_grad[0]) { + inputs2.push_back(get_shape(inputs[0])); + for (size_t i = 1; i < inputs.size(); ++i) { + inputs2.push_back(inputs[i]); } } + auto maker = CustomGradMaker(backward, inputs.size()); maker.output_size(1).output_captured(0, false); - maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(1); + maker.backward([inputs = std::move(inputs2), + grad_op_ = std::move(grad_op)](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(1); if (grad && inputs[0]) { - SmallVector args_(inputs.size() + 1); - auto&& zeros = make_empty_tensor( - grad->comp_node(), inputs[0].get(), grad->dtype()); - args_[0] = zeros.get(); + SmallVector args_(inputs.size() + 1); + auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); + args_[0] = zeros; args_[1] = grad; for (size_t i = 1; i < inputs.size(); ++i) { - args_[i + 1] = inputs[i].get(); + args_[i + 1] = inputs[i]; } - ret[0] = python::apply(grad_op_, args_)[0]; + ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; } return ret; }); - return apply(ctx); + maker.finalize(); + return imperative::apply(ApplyOp(op), inputs); } -std::optional indexingMultiAxisVec_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - auto&& op = ctx.op->cast_final_safe(); - auto&& grad_op = IndexingSetMultiAxisVec::make(op.items); - SmallVector> inputs; - if (input_requires_grad(ctx, 0)) { - inputs.push_back(get_shape(ctx.args[0])); - for (size_t i = 1; i < ctx.nargs; ++i) { - inputs.push_back(ctx.args[i]->copy()); +std::optional> indexingMultiAxisVec_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& indexingMultiAxisVec = op.cast_final_safe(); + auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); + SmallVector inputs2; + if (inputs_require_grad[0]) { + inputs2.push_back(get_shape(inputs[0])); + for (size_t i = 1; i < inputs.size(); ++i) { + inputs2.push_back(inputs[i]); } } + auto maker = CustomGradMaker(backward, inputs.size()); maker.output_size(1).output_captured(0, false); - maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(1); + maker.backward([inputs = std::move(inputs2), + grad_op_ = std::move(grad_op)](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(1); if (grad && inputs[0]) { - SmallVector args_(inputs.size() + 1); - auto&& zeros = make_empty_tensor( - grad->comp_node(), inputs[0].get(), grad->dtype()); - args_[0] = zeros.get(); + SmallVector args_(inputs.size() + 1); + auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); + args_[0] = zeros; args_[1] = grad; for (size_t i = 1; i < inputs.size(); ++i) { - args_[i + 1] = inputs[i].get(); + args_[i + 1] = inputs[i]; } - ret[0] = python::apply(grad_op_, args_)[0]; + ret[0] = 
imperative::apply(ApplyOp(*grad_op_), args_)[0]; } return ret; }); - return apply(ctx); + maker.finalize(); + return imperative::apply(ApplyOp(op), inputs); } -std::optional reduce_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - auto& op = ctx.op->cast_final_safe(); - if (op.mode == Reduce::Mode::SUM) { - if (ctx.nargs != 1) { - return {}; - } - std::array, 1> input_shapes; - if (input_requires_grad(ctx, 0)) { - input_shapes[0] = get_shape(ctx.args[0]); - } - maker.output_size(1).output_captured(0, false); - maker.backward([shapes = std::move(input_shapes)]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(1); - if (grad && shapes[0]) { - ret[0] = broadcast_to(grad, shapes[0].get()); - } - return ret; - }); - return apply(ctx); +std::optional> reduce_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto& reduce = op.cast_final_safe(); + if (reduce.mode != Reduce::Mode::SUM) { + return {}; + } + if (inputs.size() != 1) { + return {}; + } + std::array input_shapes; + if (inputs_require_grad[0]) { + input_shapes[0] = get_shape(inputs[0]); } - return {}; + auto maker = CustomGradMaker(backward, inputs.size()); + maker.output_size(1).output_captured(0, false); + maker.backward([shapes = std::move(input_shapes)](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(1); + if (grad && shapes[0]) { + ret[0] = broadcast_to(grad, shapes[0]); + } + return ret; + }); + maker.finalize(); + return imperative::apply(ApplyOp(op), inputs); } -std::optional addAxis_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - auto&& op = ctx.op->cast_final_safe(); - mgb_assert(ctx.nargs == 1); - bool flag = input_requires_grad(ctx, 0); - auto&& grad_op = RemoveAxis::make(op.axis); +std::optional> addAxis_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& addAxis = op.cast_final_safe(); + mgb_assert(inputs.size() == 1); + bool flag = inputs_require_grad[0]; + auto&& grad_op = RemoveAxis::make(addAxis.axis); std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater()); + auto maker = CustomGradMaker(backward, inputs.size()); maker.output_size(1).output_captured(0, false); - maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(1); + maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(1); if (grad && flag_) { - ret[0] = python::apply(grad_op_, grad)[0]; + ret[0] = imperative::apply(*grad_op_, grad)[0]; } return ret; }); - return apply(ctx); + maker.finalize(); + return imperative::apply(op, inputs); } -std::optional removeAxis_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - auto&& op = ctx.op->cast_final_safe(); - mgb_assert(ctx.nargs == 1); - bool flag = input_requires_grad(ctx, 0); - auto&& grad_op = AddAxis::make(op.axis); +std::optional> removeAxis_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& removeAxis = op.cast_final_safe(); + mgb_assert(inputs.size() == 1); + bool flag = inputs_require_grad[0]; + auto&& grad_op = AddAxis::make(removeAxis.axis); std::sort(grad_op->axis.begin(), grad_op->axis.end()); + auto maker = 
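// NOTE: these rules form adjoint pairs: Reduce(SUM) is undone by
// broadcast_to, AddAxis by RemoveAxis (axes sorted descending so earlier
// removals do not shift the later indices), and RemoveAxis here by AddAxis
// (axes sorted ascending, for the mirror-image reason).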
CustomGradMaker(backward, inputs.size()); maker.output_size(1).output_captured(0, false); - maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( - BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(1); + maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(1); if (grad && flag_) { - ret[0] = python::apply(grad_op_, grad)[0]; + ret[0] = imperative::apply(*grad_op_, grad)[0]; } return ret; }); - return apply(ctx); + maker.finalize(); + return imperative::apply(op, inputs); } -std::optional fastpathcopy_grad_rule( - ApplyContext& ctx, CustomBackward::Maker& maker) { - mgb_assert(ctx.nargs == 1); +std::optional> fastpathcopy_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + mgb_assert(inputs.size() == 1); + auto maker = CustomGradMaker(backward, inputs.size()); maker.output_size(1).output_captured(0, false); - maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) { - mgb_assert(ngrads == 1); - Tensor* grad = grads[0]; - apply_result_t ret(1); + maker.backward([](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + std::vector ret(1); if (grad) { - ret[0] = grad->shared_from_this(); + ret[0] = grad; } return ret; }); - return apply(ctx); + maker.finalize(); + return imperative::apply(op, inputs); } struct Init { Init() { - auto& reg = grad_rule_registry(); - reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule); - reg.emplace(Reshape::typeinfo(), reshape_grad_rule); - reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); - reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); - reg.emplace(Reduce::typeinfo(), reduce_grad_rule); - reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); - reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); - reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); + CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); + CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule); + CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule); + CustomBackward::register_grad_rule( + IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); + CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule); + CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule); + CustomBackward::register_grad_rule( + RemoveAxis::typeinfo(), removeAxis_grad_rule); + CustomBackward::register_grad_rule( + FastpathCopy::typeinfo(), fastpathcopy_grad_rule); } } _; diff --git a/imperative/python/src/intrusive_list.h b/imperative/python/src/intrusive_list.h deleted file mode 100644 index 04c8896b..00000000 --- a/imperative/python/src/intrusive_list.h +++ /dev/null @@ -1,245 +0,0 @@ -/** - * \file imperative/python/src/intrusive_list.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
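// NOTE: shape of a rule under the CustomBackward::register_grad_rule
// registry used in Init above (a sketch only; MyOp, the pass-through
// backward, and the exact template arguments are assumptions):
//
//     std::optional<std::vector<ValueRef>> myop_grad_rule(
//             const OpDef& op, Span<ValueRef> inputs,
//             Span<bool> inputs_require_grad, CustomBackward& backward) {
//         auto maker = CustomGradMaker(backward, inputs.size());
//         maker.output_size(1).output_captured(0, false);
//         maker.backward([](Span<ValueRef> grads) {
//             return std::vector<ValueRef>{grads[0]};  // identity backward
//         });
//         maker.finalize();
//         return imperative::apply(op, inputs);
//     }
//
//     // in Init: CustomBackward::register_grad_rule(
//     //         MyOp::typeinfo(), myop_grad_rule);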
- */ - -#include "megbrain/utils/metahelper.h" - -namespace mgb::imperative::python::intrusive_list { - -// copy policy -struct after_t {}; -struct before_t {}; -struct disable_t {}; - -template -struct Tail; - -// invariant: next->prev == this -template -struct Head { - Tail* next; - - Head(Tail* node = nullptr) : next(node) {} - Head(const Head&) = delete; - Head& operator=(const Head&) = delete; - Head(Head&& rhs) : next(rhs.next) { - rhs.next = nullptr; - if (next) { - next->prev = this; - } - } - Head& operator=(Head&& rhs) { - mgb_assert(!next); - next = rhs.next; - rhs.next = nullptr; - if (next) { - next->prev = this; - } - return *this; - } - ~Head() { - if (next) { - next->prev = nullptr; - } - } -}; - -// invariant: prev->next == this -template -struct Tail { - Head* prev; - - Tail(Head* node = nullptr) : prev(node) {} - Tail(const Tail&) = delete; - Tail& operator=(const Tail&) = delete; - Tail(Tail&& rhs) : prev(rhs.prev) { - rhs.prev = nullptr; - if (prev) { - prev->next = this; - } - } - Tail& operator=(Tail&& rhs) { - mgb_assert(!prev); - prev = rhs.prev; - rhs.prev = nullptr; - if (prev) { - prev->next = this; - } - return *this; - } - ~Tail() { - if (prev) { - prev->next = nullptr; - } - } -}; - -template -struct Node; - -template -class Iterator { - T* ptr; - - void inc() { ptr = static_cast(ptr->Head::next); } - void dec() { ptr = static_cast(ptr->Head::prev); } - -public: - Iterator(Head& head) : ptr(static_cast(head.next)) {} - Iterator(Tail& tail) : ptr(static_cast(tail.prev)) {} - - template - Iterator(Node& node) : ptr(static_cast(&node)) {} - - T& operator*() { return *static_cast(ptr); } - T* operator->() { return static_cast(ptr); } - - operator bool() { return ptr; } - bool operator==(const Iterator& rhs) { return ptr == rhs.ptr; } - - Iterator& operator++() { - inc(); - return *this; - } - Iterator& operator--() { - dec(); - return *this; - } - Iterator operator++(int) { - auto ret = *this; - inc(); - return ret; - } - Iterator operator--(int) { - auto ret = *this; - dec(); - return ret; - } -}; - -// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. -// Instead, nodes may join or leave a list freely. -// NOTE: Derived classes have to explicitly declare copy / assignment as default, -// otherwise the compiler generated version would use the const T& signature, -// which is deleted. -template -struct Node : Tail, Node, T>>, - Head, Node, T>> { -private: - using this_t = Node; - using U = std::conditional_t, this_t, T>; - -public: - using head_t = Head; - using tail_t = Tail; - using head_t::next; - using tail_t::prev; - - Node() = default; - Node(const this_t&) = delete; - this_t& operator=(const this_t&) = delete; - - //! constructed node is inserted after the input node - Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { - node.next = this; - if (next) { - next->prev = this; - } - } - - //! 
constructed node is inserted before the input node - Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { - node.prev = this; - if (prev) { - prev->next = this; - } - } - - Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { - rhs.prev = nullptr; - rhs.next = nullptr; - if (prev) { - prev->next = this; - } - if (next) { - next->prev = this; - } - } - - Node& operator=(this_t&& rhs) { - unlink(); - prev = rhs.prev; - next = rhs.next; - rhs.prev = nullptr; - rhs.next = nullptr; - if (prev) { - prev->next = this; - } - if (next) { - next->prev = this; - } - return *this; - } - - template < - typename p = policy, - typename = std::enable_if_t< - std::is_same_v || std::is_same_v, void>> - Node(this_t& rhs) : Node(policy{}, rhs) {} - - template < - typename p = policy, - typename = std::enable_if_t< - std::is_same_v || std::is_same_v, void>> - this_t& operator=(this_t& rhs) { - insert(policy{}, rhs); - return *this; - } - - void unlink() { - if (prev) { - prev->next = next; - } - if (next) { - next->prev = prev; - } - prev = nullptr; - next = nullptr; - } - - //! this node is unlinked from its list and inserted after the input node - void insert(after_t, head_t& node) { - unlink(); - prev = &node; - next = node.next; - node.next = this; - if (next) { - next->prev = this; - } - } - - //! this node is unlinked from its list and inserted before the input node - void insert(before_t, tail_t& node) { - unlink(); - next = &node; - prev = node.prev; - node.prev = this; - if (prev) { - prev->next = this; - } - } - - void insert_before(tail_t& node) { insert(before_t{}, node); } - void insert_after(head_t& node) { insert(after_t{}, node); } - - ~Node() { unlink(); } -}; - -} // namespace mgb::imperative::python::intrusive_list diff --git a/imperative/python/src/module_trace.cpp b/imperative/python/src/module_trace.cpp deleted file mode 100644 index 702944cf..00000000 --- a/imperative/python/src/module_trace.cpp +++ /dev/null @@ -1,42 +0,0 @@ -/** - * \file imperative/python/src/module_trace.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. 
- */
-
-#include "./module_trace.h"
-#include "./helper.h" // include op pybind11 caster
-
-namespace py = pybind11;
-
-namespace mgb::imperative::python {
-
-apply_result_t apply_module_trace(ApplyContext& ctx) {
-    apply_result_t outputs;
-
-    auto args = py::tuple(ctx.nargs + 1);
-    args[0] = py::cast(ctx.op);
-    for (size_t i = 0; i < ctx.nargs; i++) {
-        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
-    }
-    auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr);
-    if (!pyout)
-        throw py::error_already_set();
-    auto ret = py::reinterpret_steal<py::object>(pyout);
-
-    // assumption: python function always returns PyList
-    auto tup = py::reinterpret_borrow<py::list>(ret);
-    for (auto i = 0; i < tup.size(); i++) {
-        auto tw = TensorWrapper::try_cast(tup[i].ptr());
-        outputs.emplace_back(tw->m_tensor);
-    }
-    return outputs;
-}
-
-} // namespace mgb::imperative::python
diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h
index bd4d47f4..410900ec 100644
--- a/imperative/python/src/module_trace.h
+++ b/imperative/python/src/module_trace.h
@@ -11,10 +11,50 @@
 
 #pragma once
 
+#include "megbrain/imperative/transformations/trace.h"
+#include "megbrain/imperative/utils/map.h"
+
 #include "./tensor.h"
 
 namespace mgb::imperative::python {
 
-apply_result_t apply_module_trace(ApplyContext& ctx);
+namespace py = pybind11;
+
+class ModuleTraceTransformation final : public Transformation {
+private:
+    py::function m_hook_fn;
+    int m_enabled = 0;
+
+    std::vector<ValueRef> apply_module_trace_hook(
+            const OpDef& op, Span<ValueRef> input_values) {
+        py::list input_tws;
+        for (auto&& input_value : input_values) {
+            input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
+        }
+        py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
+        std::vector<ValueRef> outputs;
+        for (auto&& output_tw : output_tws) {
+            outputs.push_back(
+                    TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data());
+        }
+        return outputs;
+    }
+
+public:
+    ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
+    std::vector<ValueRef> apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override {
+        if (op.is<ApplyOp>() && m_enabled > 0) {
+            auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
+            return outputs;
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    }
+
+    ValueRef unwrap(ValueRef value) override { return value; }
+
+    std::string name() const override { return "ModuleTraceTransformation"; }
+};
 
 } // namespace mgb::imperative::python
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index be184cff..a53b1027 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -185,7 +185,8 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
 }
 
 PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
-        {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
+        {const_cast<char*>("scope"), py_get_scope, py_set_scope,
+         const_cast<char*>("scope"), NULL},
        {NULL}};
 
 Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
@@ -556,12 +557,6 @@ void init_ops(py::module m) {
     m.def(
             "delete_rng_handle",
             [](size_t handle) {
-                // RNG op might execute after handle released due to async dispatch, so
-                // we need sync before delete a handle to avoid memory leak or
-                // use-after-free
-                if (python::interpreter_for_py->check_available()) {
-                    python::interpreter_for_py->sync();
-                }
                 mgb::CompNode::sync_all();
                 py_task_q.wait_all_task_finish();
                 rng::delete_handle(handle);
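The deleted apply_module_trace above trampolined into Python through the cpp_apply_module_trace pointer; the new ModuleTraceTransformation keeps the same contract (the hook receives the op plus the input tensors and returns the output tensors) but hangs it on the transformation stack. Below is a hedged sketch of driving it from Python using the entry points this diff registers later in init_tensor (set_module_trace_hook, set_module_tracing, unset_module_tracing, apply); the logging hook itself is illustrative and is not the shipped module tracer.

import megengine
from megengine.core._imperative_rt.core2 import (
    apply,
    set_module_trace_hook,
    set_module_tracing,
    unset_module_tracing,
)

def logging_hook(op, *inputs):
    # Re-entering apply() while module tracing is enabled would hit the hook
    # again, so tracing is toggled off around the real dispatch.
    unset_module_tracing()
    try:
        outputs = apply(op, *inputs)
    finally:
        set_module_tracing()
    print("module-trace:", op, "->", len(outputs), "output(s)")
    return list(outputs)

set_module_trace_hook(logging_hook)   # must be set before enabling tracing
set_module_tracing()
# ... run a module here: every ApplyOp is routed through logging_hook ...
unset_module_tracing()

Toggling tracing off around the inner apply mirrors what any real hook must do, since ModuleTraceTransformation would otherwise route the hook's own ops back into the hook.

diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index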
f415f572..af0f8cae 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -15,7 +15,14 @@ #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/profiler.h" +#include "megbrain/imperative/transformations/eval.h" +#include "megbrain/imperative/transformations/lazy.h" +#include "megbrain/imperative/transformations/scalar.h" +#include "megbrain/imperative/transformations/symbol.h" +#include "megbrain/imperative/transformations/trace.h" +#include "megbrain/imperative/utils/map.h" #include "megbrain/opr/io.h" +#include "megbrain/plugin/profiler.h" #include "./common.h" #include "./grad.h" @@ -24,7 +31,7 @@ #include "./module_trace.h" #include "./numpy_dtypes.h" #include "./tensor.h" -#include "./trace.h" +#include "./transformation.h" #include #include @@ -36,143 +43,19 @@ #include +#include "../../src/impl/mgb_cg_impl.h" + namespace py = pybind11; namespace views = ranges::views; namespace mgb::imperative::python { -interpreter::Interpreter::Channel* interpreter_for_py; - -PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing; -PyObject* cpp_apply_backward_varnode; -PyObject* cpp_apply_module_trace; - -std::shared_ptr make_const(imperative::TensorPtr value) { - if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { - return std::make_shared( - interpreter_for_py->put(value->dev_tensor(), value->get_value())); - } - py::tuple tup(6); - auto data = value->get_value(); - tup[0] = py::reinterpret_steal( - ndarray_from_tensor(data, npy::ShareType::MUST_SHARE)); - tup[1] = value->dtype(); - tup[2] = value->comp_node(); - tup[3] = true; - tup[4] = false; - tup[5] = py::none{}; - auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); - if (!py_ret) - throw py::error_already_set(); - auto py_list = py::reinterpret_steal(py_ret); - auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr()); - auto tensor = tensor_wrapper->m_tensor; - return tensor_wrapper->m_tensor; -} - -#define REGISTE_APPLY_FUNC(mode) \ - void set_##mode(py::object pyf) { mode = pyf.ptr(); } - -REGISTE_APPLY_FUNC(cpp_apply_with_tracing) -REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing) -REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) -REGISTE_APPLY_FUNC(cpp_apply_module_trace) - -#undef REGISTE_APPLY_FUNC - -Tensor::flags_t ApplyContext::global_disable = 0; -Tensor::flags_t ApplyContext::global_enable = 0; - -void set_tracing() { - ApplyContext::global_enable |= Tensor::Flags::TRACE; -} -void unset_tracing() { - ApplyContext::global_enable &= ~Tensor::Flags::TRACE; -} - -void set_module_tracing() { - ApplyContext::global_enable |= Tensor::Flags::MODULE_TRACE; -} -void unset_module_tracing() { - ApplyContext::global_enable &= ~Tensor::Flags::MODULE_TRACE; +namespace { +WeakKeyMap module_trace_info_map; } -bool is_tracing_module() { - return ApplyContext::global_enable & Tensor::Flags::MODULE_TRACE; -} - -bool skip_tracing = false; - -apply_result_t apply(ApplyContext& ctx) { - // emulating scalar should be put to specific op's apply, e.g., - // elementwise, reduce, typecvt. Currently it's still handled at python - // side. 
It could be move to C++ side if it has an impact on performance - auto flags = ctx.flags & ~ApplyContext::global_disable; - flags = flags | ApplyContext::global_enable; - - if (flags & Tensor::Flags::SCALAR) { - // TODO: emulate scalar - } - if (flags & Tensor::Flags::GRAD) { - return apply_grad(ctx); - } - - if (auto* op = ctx.op->try_cast_final()) { - py::tuple pyin(ctx.nargs); - for (size_t i = 0; i < ctx.nargs; ++i) { - pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); - } - auto f = py::getattr(op->obj, "_default_rule"); - auto pyout = py::reinterpret_steal( - PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); - if (!pyout) - throw py::error_already_set(); - if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { - return {tw->m_tensor}; - } - apply_result_t ret; - ret.reserve(py::len(pyout)); - for (auto&& i : pyout) { - auto* tw = TensorWrapper::try_cast(i.ptr()); - mgb_assert(tw); - ret.push_back(tw->m_tensor); - } - return ret; - } - - if (flags & Tensor::Flags::MODULE_TRACE) { - return apply_module_trace(ctx); - } - - if (flags & Tensor::Flags::TRACE) { - return apply_trace(ctx); - } else { - SmallVector handles(ctx.nargs); - for (size_t i = 0; i < ctx.nargs; ++i) { - handles[i] = ctx.args[i]->m_handle.get(); - } - - apply_result_t outputs; - - // fast copy without really applying - if (ctx.op->same_type()) { - mgb_assert(ctx.nargs == 1); - outputs.reserve(ctx.nargs); - outputs.emplace_back(std::make_shared(ctx.args[0]->m_handle)); - return outputs; - } - - auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); - - outputs.reserve(output_handles.size()); - for (auto h : output_handles) { - outputs.emplace_back(std::make_shared(h)); - } - return outputs; - } - - mgb_assert(0); -} +interpreter::Interpreter::Channel* interpreter_for_py = nullptr; +PyTypeObject* py_tensor_type = nullptr; PyObject* py_apply( PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) { @@ -189,62 +72,70 @@ PyObject* py_apply( return nullptr; } - auto* op = args[0]; - - PyTypeObject* pytype = args[1]->ob_type; + auto* py_op = args[0]; - // check if pytype is Parameter(and all other python Tensor's derived class), - // if yes, using it's tp_base(python Tensor) - if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) { - pytype = pytype->tp_base; - } ++args; --nargs; - ApplyContext ctx; - ctx.flags = 0; - ctx.op = py::handle(op).cast>(); - SmallVector tensors(nargs); - ctx.args = &tensors[0]; - ctx.nargs = nargs; - ctx.pytype = pytype; + auto op = py::handle(py_op).cast>(); + SmallVector tensors(nargs); if (py::isinstance(py::handle(args[0]))) { - SmallVector vinputs(nargs); + // swap to a special context to reuse scalar handle + TransformationContext symbol_var_context; + Transformation::swap_context(symbol_var_context); + CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }}; + auto* graph = + py::handle(args[0]).cast()->m_node->owner_graph(); + std::make_shared(graph)->register_at( + Transformation::top()); + std::make_shared()->register_at( + Transformation::top()); + SmallVector inputs(nargs); for (size_t i = 0; i < nargs; ++i) { - vinputs[i] = py::handle(args[i]).cast()->m_node; + auto* py_input = py::handle(args[i]).cast(); + ValueRef input = SymbolValue::make(py_input->m_node); + if (py_input->is_scalar) { + input = ScalarValue::make(input); + } + inputs[i] = input; } - auto op = ctx.op.get(); - auto rst = OpDef::apply_on_var_node(*op, vinputs); - auto ret = pybind11::tuple(rst.size()); + auto outputs = 
imperative::apply(*op, inputs); + auto ret = pybind11::tuple(outputs.size()); auto typeobj = py::handle(args[0]).get_type(); - for (size_t i = 0; i < rst.size(); ++i) { - ret[i] = typeobj(pybind11::cast( - rst[i], pybind11::return_value_policy::automatic)); + for (size_t i = 0; i < outputs.size(); ++i) { + bool is_scalar = false; + if (auto* scalar_value = outputs[i].as()) { + outputs[i] = scalar_value->value(); + is_scalar = true; + } + auto* node = outputs[i].cast().node(); + ret[i] = typeobj( + pybind11::cast(node, pybind11::return_value_policy::automatic)); + py::handle(ret[i]).cast()->is_scalar = is_scalar; } return ret.release().ptr(); } for (size_t i = 0; i < nargs; ++i) { if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { - auto* t = tensors[i] = tw->m_tensor.get(); - ctx.flags |= t->m_flags; + tensors[i] = tw->m_tensor->data(); } else { PyErr_SetString( PyExc_TypeError, ssprintf( "op %s expect type Tensor as inputs, got %s actually", - ctx.op->make_name().c_str(), Py_TYPE(args[i])->tp_name) + op->make_name().c_str(), Py_TYPE(args[i])->tp_name) .c_str()); return nullptr; } } - auto outputs = apply(ctx); + auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs}); size_t nout = outputs.size(); auto ret = py::tuple(nout); for (size_t i = 0; i < nout; ++i) { - ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); + ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i])); } return ret.release().ptr(); } @@ -264,28 +155,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { if (nargs > 1) { throw py::type_error("expect 1 argument"); } - m_tensor = t->m_tensor; + m_tensor = t->m_tensor->copy(); } else { if (nargs == 1) { auto arg0 = PyTuple_GetItem(args, 0); - // for lazy_eval_tensor - if (strstr(arg0->ob_type->tp_name, "VarNode")) { - if (PyObject_HasAttrString(arg0, "_node")) { - arg0 = PyObject_GetAttrString(arg0, "_node"); - } - m_tensor = - std::make_shared(py::handle(arg0).cast()); + // for DeviceTensorND + if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { + auto dv = py::handle(arg0).cast(); + m_tensor = std::make_shared(imperative::apply( + CreateTensor(CreateTensor::Common, dv.comp_node(), dv.layout()), + DeviceStorage::make(dv.storage()))[0]); } else { - // for DeviceTensorND - if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { - auto dv = py::handle(arg0).cast(); - interpreter::Interpreter::Handle handle = - interpreter_for_py->put(dv, {}); - m_tensor = std::make_shared(handle); - } else { - throw py::type_error( - "single argument is not tensor, varnode or devicetensor"); - } + throw py::type_error( + "single argument is not tensor, varnode or devicetensor"); } } else { py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType @@ -302,329 +184,140 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { name = tup[nargs - 1].cast(); // const op - if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) { - auto py_ret = - PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); - if (!py_ret) - throw py::error_already_set(); - auto py_list = py::reinterpret_steal(py_ret); - if (auto* t = try_cast(py_list[0].ptr())) { - m_tensor = t->m_tensor; - } - return; - } - - interpreter::Interpreter::Handle handle; { + CreateTensor::Kind kind = is_const ? CreateTensor::Const + : no_cache ? 
CreateTensor::Unique + : CreateTensor::Common; HostTensorND ret(cn); - handle = interpreter_for_py->put( - npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), - no_cache); + ret = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype); + mgb_assert( + ret.layout().is_empty() || ret.layout().is_contiguous(), + "host value should be continuous"); + ValueShape shape; + for (size_t i = 0; i < data.ndim(); ++i) { + shape[shape.ndim++] = data.shape(i); + } + m_tensor = std::make_shared(imperative::apply( + CreateTensor(kind, cn, ret.dtype(), shape), + HostStorage::make(ret.storage()))[0]); } - m_tensor = std::make_shared(handle); - m_tensor->user_custom_name = name; + if (!name.empty()) { + m_tensor->reset( + imperative::apply(RenameValue(name), m_tensor->data())[0]); + mgb_assert( + ((std::string&)*m_tensor->data().name()) == name, + "result name incorrect"); + } if (data.ndim() == 0) { - m_tensor->m_flags |= Tensor::Flags::SCALAR; + mgb_assert(m_tensor->is_scalar(), "result should be scalar"); } } } } -#define REGISTE_TENSORWRAPPER_FUNC(type, member) \ - PyObject* TensorWrapper::member() { \ - return py::cast(m_tensor->m_trace_info.member).release().ptr(); \ - } \ - void TensorWrapper::set_##member(PyObject* dest) { \ - auto py_dest = py::reinterpret_borrow(dest); \ - type real_dest = py_dest.cast(); \ - m_tensor->m_trace_info.member = real_dest; \ - } - -REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle) -REGISTE_TENSORWRAPPER_FUNC(bool, recording) - -#undef REGISTE_TENSORWRAPPER_FUNC - PyObject* TensorWrapper::module_trace_info() { - if (!m_tensor->m_module_trace_info.ptr()) { + if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { + return module_trace_info->inc_ref().ptr(); + } else { PyErr_SetString( PyExc_AttributeError, "Has no attribute named \'_NodeMixin__node\', please " "set it first"); return nullptr; } - return m_tensor->m_module_trace_info.inc_ref().ptr(); } void TensorWrapper::set_module_trace_info(PyObject* obj) { - m_tensor->m_module_trace_info = py::reinterpret_borrow(obj); + module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow(obj); } -#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ - PyObject* TensorWrapper::member() { \ - if (m_tensor->m_trace_info.member) { \ - return m_tensor->m_trace_info.member; \ - } else { \ - Py_RETURN_NONE; \ - } \ - } \ - void TensorWrapper::set_##member(PyObject* dest) { \ - if (dest == Py_None) { \ - Py_XDECREF(m_tensor->m_trace_info.member); \ - m_tensor->m_trace_info.member = nullptr; \ - } else { \ - Py_INCREF(dest); \ - m_tensor->m_trace_info.member = dest; \ - } \ - } - -REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) -REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info) - -#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC - -#define SET_GET_NAME(member) \ - PyObject* TensorWrapper::member() { \ - return py::cast(m_tensor->member).release().ptr(); \ - } \ - void TensorWrapper::set_##member(PyObject* dest) { \ - auto py_dest = py::reinterpret_borrow(dest); \ - m_tensor->member = py_dest.cast(); \ - } -SET_GET_NAME(user_custom_name) -SET_GET_NAME(automatic_name) -#undef SET_GET_NAME +void TensorWrapper::_set_name(PyObject* dest) { + auto py_dest = py::reinterpret_borrow(dest); + auto name = py_dest.cast(); + m_tensor->set_name(name); +} -PyObject* TensorWrapper::handle() { - return py::cast(m_tensor->m_handle).release().ptr(); +PyObject* TensorWrapper::_detail() { + return py::str(m_tensor->data().unwrap().to_string()).release().ptr(); } -void TensorWrapper::set_handle(PyObject* dest) 
{ - auto py_dest = py::reinterpret_borrow(dest); - SharedHandle real_dest = py_dest.cast(); - m_tensor->m_handle = std::move(real_dest); +void TensorWrapper::_watch() { + m_tensor->data().watch(); } PyObject* TensorWrapper::shape() { - // if it's tracing compiled mode, get value from compiled_info - if (m_tensor->m_trace_info.compiled_info != nullptr) { - if (m_tensor->m_flags & Tensor::Flags::SCALAR) { - return PyTuple_New(0); - } - PyObject* shp = - PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape"); - if (shp == Py_None) { - throw TraceReadError("shape of this tensor is not read in trace"); - } - return shp; - } - - // inside trace, if tensor shape is useful for other operations, set shape_read = true - if (m_tensor->m_trace_info.recording && !skip_tracing) { - PyObject_SetAttrString( - m_tensor->m_trace_info.trace_mixin_info, "shape_read", - py::cast(true).release().ptr()); - } - - if (m_tensor->m_flags & Tensor::Flags::SCALAR) { - return PyTuple_New(0); - } - - TensorShape shape; - if (m_tensor->m_var) { // get shape from m_var - auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); - auto&& type = mgr.get_infer_type(m_tensor->m_var); - using InferType = cg::static_infer::InferType; - if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) { - Py_RETURN_NONE; - } - auto* tshp = mgr.infer_shape_fallible(m_tensor->m_var); - if (!tshp) { - Py_RETURN_NONE; - } - shape = *tshp; - } else { - py::gil_scoped_release _; - shape = m_tensor->shape(); - } + auto shape = m_tensor->shape(); - if (!shape.ndim) { + if (!shape) { Py_RETURN_NONE; } - py::tuple ret(shape.ndim); - for (size_t i = 0; i < shape.ndim; ++i) { - ret[i] = shape[i]; + py::tuple ret(shape->ndim); + for (size_t i = 0; i < shape->ndim; ++i) { + ret[i] = shape->at(i); } return ret.release().ptr(); } PyObject* TensorWrapper::dtype() { - if (m_tensor->m_var) { - return py::cast(m_tensor->m_var->dtype()).release().ptr(); - } return py::cast(m_tensor->dtype()).release().ptr(); } PyObject* TensorWrapper::device() { - if (m_tensor->m_var) { - return py::cast(m_tensor->m_var->comp_node()).release().ptr(); - } return py::cast(m_tensor->comp_node()).release().ptr(); } PyObject* TensorWrapper::numpy() { - if (m_tensor->m_trace_info.compiled_info != nullptr) { - PyObject* np_val = PyObject_CallMethod( - m_tensor->m_trace_info.compiled_info, "numpy", nullptr); - if (!np_val) - throw py::error_already_set(); - if (np_val == Py_None) { - throw TraceReadError("value of this tensor is not read in trace"); - } - if (m_tensor->m_flags & Tensor::Flags::SCALAR) { - PyObject* np_scalar = - PyArray_Squeeze(reinterpret_cast(np_val)); - Py_DECREF(np_val); - return np_scalar; - } - return np_val; - } - - if (m_tensor->m_trace_info.recording && !skip_tracing) { - PyObject_SetAttrString( - m_tensor->m_trace_info.trace_mixin_info, "value_read", - py::cast(true).release().ptr()); - } - - if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) { - auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); - auto&& type = mgr.get_infer_type(m_tensor->m_var); - using InferType = cg::static_infer::InferType; - if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { - PyErr_SetString(PyExc_ValueError, "tensor invalid"); - return nullptr; - } - auto* val = mgr.infer_value_fallible(m_tensor->m_var); - if (!val) { - PyErr_SetString(PyExc_ValueError, "tensor invalid"); - return nullptr; - } - auto np_val = py::cast(*val).attr("numpy")(); - if (m_tensor->m_flags & Tensor::Flags::SCALAR) { - 
return PyArray_Squeeze( - reinterpret_cast(np_val.release().ptr())); - } - return np_val.release().ptr(); - } - auto&& hv = [&]() { - py::gil_scoped_release _; - return interpreter_for_py->get_value(m_tensor->m_handle.get()); - }(); + auto hv = m_tensor->numpy(); + // if (!hv) { + // PyErr_SetString(PyExc_ValueError, "tensor invalid"); + // return nullptr; + // } auto arr = py::reinterpret_steal( - npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); + npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE)); if (!arr) { PyErr_SetString(PyExc_ValueError, "tensor invalid"); return nullptr; } - - if (m_tensor->m_flags & Tensor::Flags::SCALAR) { + if (hv->shape().is_scalar()) { mgb_assert(PyArray_Check(arr.ptr())); return PyArray_Squeeze(reinterpret_cast(arr.ptr())); } return arr.release().ptr(); } -PyObject* TensorWrapper::varnode() { - if (m_tensor->m_var) { - return py::cast(m_tensor->m_var).release().ptr(); - } - Py_RETURN_NONE; -} - void TensorWrapper::reset(PyObject* tensor) { TensorWrapper* t = TensorWrapper::try_cast(tensor); if (!t) { throw py::type_error("expect Tensor"); } - std::string user_custom_name = m_tensor->user_custom_name; - std::string automatic_name = m_tensor->automatic_name; - auto module_trace_info = m_tensor->m_module_trace_info; - m_tensor = t->m_tensor; - m_tensor->m_module_trace_info = module_trace_info; - m_tensor->user_custom_name = user_custom_name; - m_tensor->automatic_name = automatic_name; -} - -void TensorWrapper::reset_varnode() { - m_tensor->m_var = nullptr; + m_tensor->reset(t->m_tensor->data()); } PyObject* TensorWrapper::detach() { - PyObject* self = wrap_t::pycast(this); - PyTypeObject* pytype = self->ob_type; - - static std::shared_ptr op = std::shared_ptr(new FastpathCopy()); - auto new_tensor = python::apply(op, m_tensor)[0]; - new_tensor->m_grad_info_dict = {}; - auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); - return ret.release().ptr(); + auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0]; + return TensorWrapper::make(py_tensor_type, detached).release().ptr(); } PyObject* TensorWrapper::_dev_tensor() { - if (m_tensor->m_trace_info.compiled_info != nullptr) { - auto* dev_tensor = PyObject_CallMethod( - m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr); - if (!dev_tensor) - throw py::error_already_set(); - if (dev_tensor == Py_None) { - throw TraceReadError("raw data of this tensor is not read in trace"); - } - - // set m_handle to make it a real tensor - auto py_dev_tensor = py::reinterpret_borrow(dev_tensor); - auto sh = interpreter_for_py->put(py_dev_tensor.cast(), {}); - m_tensor->m_handle = std::move(SharedHandle(sh)); - - // compiled info is useless after m_handle is set - Py_DECREF(m_tensor->m_trace_info.compiled_info); - m_tensor->m_trace_info.compiled_info = nullptr; - - return dev_tensor; - } - if (m_tensor->m_trace_info.recording && !skip_tracing) { - PyObject_SetAttrString( - m_tensor->m_trace_info.trace_mixin_info, "data_read", - py::cast(true).release().ptr()); - } - auto dev_tensor = [&]() { - py::gil_scoped_release _; - return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); - }(); - return py::cast(dev_tensor).release().ptr(); + auto dv = m_tensor->data().dev_tensor(); + // TODO: handle scalar + return py::cast(dv->as_nd(true)).release().ptr(); } void TensorWrapper::_drop() { - interpreter_for_py->drop(m_tensor->m_handle.get()); + imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data()); } PyObject* TensorWrapper::isscalar() { - if 
(m_tensor->m_flags & Tensor::Flags::SCALAR) { + if (m_tensor->is_scalar()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } -void TensorWrapper::setscalar() { - m_tensor->m_flags |= Tensor::Flags::SCALAR; -} - -void TensorWrapper::unsetscalar() { - m_tensor->m_flags &= ~Tensor::Flags::SCALAR; -} - struct TensorWeakRef { std::weak_ptr wptr; @@ -632,7 +325,7 @@ struct TensorWeakRef { py::object operator()() { if (auto p = wptr.lock()) { - return TensorWrapper::make(p); + return TensorWrapper::make(py_tensor_type, p); } return py::none(); } @@ -875,9 +568,18 @@ WRAP_FUNC_PY35(get_device); void init_tensor(py::module m) { imperative::Tensor::static_initialize(); - static auto sl_interpreter_for_py = - interpreter::Interpreter::inst().create_channel(); - interpreter_for_py = sl_interpreter_for_py.get(); + + static auto& transformations = TransformationManager::get_instance(); + + using Segment = TransformationManager::Segment; + + auto* channel = interpreter::Interpreter::inst().create_channel().release(); + interpreter_for_py = channel; + transformations.register_at( + std::make_shared( + std::unique_ptr(channel))); + transformations.register_at( + std::make_shared()); static py::exception py_async_error( m, "AsyncError", PyExc_RuntimeError); @@ -919,34 +621,14 @@ void init_tensor(py::module m) { .def_getset<&TensorWrapper::device>("device") .def<&TensorWrapper::reset>("_reset") .def<&TensorWrapper::isscalar>("_isscalar") - .def<&TensorWrapper::setscalar>("_setscalar") - .def<&TensorWrapper::unsetscalar>("_unsetscalar") .def<&TensorWrapper::detach>("detach") + // TODO: remove this .def<&TensorWrapper::_dev_tensor>("_dev_tensor") .def<&TensorWrapper::_drop>("_drop") - .def<&TensorWrapper::reset_varnode>("_reset_varnode") .def<&TensorWrapper::_use_cnt>("_use_cnt") - .def_getset<&TensorWrapper::varnode>("_varnode") - .def_getset< - &TensorWrapper::mixin_handle, - &TensorWrapper::set_mixin_handle>("_mixin_handle") - .def_getset< - &TensorWrapper::recording, &TensorWrapper::set_recording>( - "_recording") - .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>( - "_handle") - .def_getset< - &TensorWrapper::compiled_info, - &TensorWrapper::set_compiled_info>("_compiled_info") - .def_getset< - &TensorWrapper::trace_mixin_info, - &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") - .def_getset< - &TensorWrapper::user_custom_name, - &TensorWrapper::set_user_custom_name>("c_name") - .def_getset< - &TensorWrapper::automatic_name, - &TensorWrapper::set_automatic_name>("_name") + .def<&TensorWrapper::_detail>("_detail") + .def<&TensorWrapper::_set_name>("_set_name") + .def<&TensorWrapper::_watch>("_watch") .def_getset< &TensorWrapper::module_trace_info, &TensorWrapper::set_module_trace_info>("_NodeMixin__node") @@ -989,13 +671,9 @@ void init_tensor(py::module m) { throw py::value_error("value invalid!"); } auto np_val = py::cast(*val).attr("numpy")(); - if (v->is_scalar) { - return py::object(py::array(np_val).squeeze()); - } return np_val; }) .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) - .def("_setscalar", [](PySymbolVar* v) { return v->is_scalar = true; }) .def(py::init([](cg::VarNode* node) { return std::make_shared(node); }), @@ -1015,79 +693,75 @@ void init_tensor(py::module m) { } } - static constexpr auto sync_py_task_q = [] { py_task_q.wait_all_task_finish(); }; + static constexpr auto sync_py_task_q = [] { + py::gil_scoped_release _; + py_task_q.wait_all_task_finish(); + }; - m.def("set_option", [](std::string name, size_t value) { - 
interpreter_for_py->set_option(name, value); + m.def("set_option", [channel](std::string name, size_t value) { + channel->set_option(name, value); }); - m.def("clear_candidates", []() { interpreter_for_py->clear_candidates(); }); + m.def("clear_candidates", [channel]() { channel->clear_candidates(); }); m.def("get_option", - [](std::string name) { return interpreter_for_py->get_option(name); }); + [channel](std::string name) { return channel->get_option(name); }); m.def("_set_drop_flag", - [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); }); - m.def("config_async_level", [](int level) { + [channel](bool flag) { channel->set_option("enable_drop", flag); }); + m.def("config_async_level", [channel](int level) { mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2"); - interpreter_for_py->set_option("async_level", level); + channel->set_option("async_level", level); }); m.def("get_async_level", - []() { return interpreter_for_py->get_option("async_level"); }); - m.def("set_buffer_length", [](int length) { + [channel]() { return channel->get_option("async_level"); }); + m.def("set_buffer_length", [channel](int length) { mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)"); - interpreter_for_py->set_option("buffer_length", length); + channel->set_option("buffer_length", length); + }); + m.def("push_scope", [channel](std::string name) { + Transformation::push_scope(name); + channel->push_scope(name); + }); + m.def("pop_scope", [channel](std::string name) { + channel->pop_scope(name); + Transformation::pop_scope(name); + }); + m.def("start_profile", [channel](imperative::Profiler::options_t options) { + channel->sync(); + imperative::Profiler::load_options(std::move(options)); + imperative::Profiler::start_profile(); + channel->start_profile(); + }); + m.def("stop_profile", [channel]() -> std::function { + channel->stop_profile(); + channel->sync(); + imperative::Profiler::stop_profile(); + auto results = std::make_shared( + imperative::Profiler::collect()); + return [results = results](std::string basename, std::string format) mutable { + imperative::Profiler::dump_profile(basename, format, std::move(*results)); + results = nullptr; + }; + }); + m.def("sync", [channel]() { + if (channel->check_available()) { + channel->sync(); + } + sync_py_task_q(); + }); + m.def("full_sync", [channel]() { + if (channel->check_available()) { + channel->sync(); + } + CompNode::sync_all(); + CompNode::foreach ([](CompNode cn) { + auto err = cn.check_async_error(); + mgb_assert(!err, "%s", err->what()); + }); + sync_py_task_q(); + }); + m.def("close", [channel]() { + channel->close(); + sync_py_task_q(); }); - m.def("push_scope", [](std::string name) { interpreter_for_py->push_scope(name); }); - m.def("pop_scope", [](std::string name) { interpreter_for_py->pop_scope(name); }); - m.def( - "start_profile", - [](imperative::Profiler::options_t options) { - interpreter_for_py->sync(); - imperative::Profiler::load_options(std::move(options)); - imperative::Profiler::start_profile(); - interpreter_for_py->start_profile(); - }, - py::call_guard()); - m.def( - "stop_profile", - []() -> std::function { - interpreter_for_py->stop_profile(); - interpreter_for_py->sync(); - imperative::Profiler::stop_profile(); - auto results = std::make_shared( - imperative::Profiler::collect()); - return [results = results]( - std::string basename, std::string format) mutable { - imperative::Profiler::dump_profile( - basename, format, std::move(*results)); - results = nullptr; 
- }; - }, - py::call_guard()); - m.def( - "sync", - []() { - interpreter_for_py->sync(); - sync_py_task_q(); - }, - py::call_guard()); - m.def( - "full_sync", - []() { - interpreter_for_py->sync(); - CompNode::sync_all(); - CompNode::foreach ([](CompNode cn) { - auto err = cn.check_async_error(); - mgb_assert(!err, "%s", err->what()); - }); - sync_py_task_q(); - }, - py::call_guard()); - m.def( - "close", - []() { - interpreter_for_py->close(); - sync_py_task_q(); - }, - py::call_guard()); py::handle grad_key_type = GradKeyWrapper::wrap_t::type() @@ -1095,37 +769,315 @@ void init_tensor(py::module m) { .def<&GradKeyWrapper::is_attached_to>("is_attached_to") .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>( "name") - .def_getset< - &GradKeyWrapper::get_priority, - &GradKeyWrapper::set_priority>("priority") + .def<&GradKeyWrapper::enter>("enter") + .def<&GradKeyWrapper::exit>("exit") + .def<&GradKeyWrapper::suppress>("suppress") + .def<&GradKeyWrapper::resume>("resume") .finalize(); if (!grad_key_type) throw py::error_already_set(); py::setattr(m, "GradKey", grad_key_type); m.def("backward", &GradKeyWrapper::backward); + m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure); - m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); - m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); - m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); - m.def("set_cpp_apply_module_trace", &set_cpp_apply_module_trace); - m.attr("skip_tracing") = &skip_tracing; - - py::class_(m, "SharedHandle") - .def(py::init()) - .def("__eq__", - [](SharedHandle& thish, SharedHandle& thath) { - return (thish.get() == thath.get()); + m.def("set_py_tensor_type", [](py::object type_obj) { + py_tensor_type = reinterpret_cast(type_obj.inc_ref().ptr()); + }); + + /** + * \brief trace proxy + * + */ + struct Trace { + bool symbolic = false; + bool no_exec = false; + bool capture_as_const = false; + bool profile = false; + bool record_input_shapes = false; + py::function options_visitor; + std::shared_ptr tracing; + std::shared_ptr compiled; + std::shared_ptr lazy_eval; + std::pair> profiler; + std::optional trace_result; + std::function array_comparator; + + bool compare_value(ValueRef lhs, ValueRef rhs) { + if (!lhs.shape()->eq(*rhs.shape())) { + return false; + } + HostTensorND lvalue = lhs.numpy()->as_nd(true); + HostTensorND rvalue = rhs.numpy()->as_nd(true); + auto larr = py::reinterpret_steal( + npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE)); + auto rarr = py::reinterpret_steal( + npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE)); + return array_comparator(larr, rarr); + } + + void enter() { + auto& self = *this; + if (!self.trace_result) { // untraced + self.tracing = std::make_shared( + self.capture_as_const, self.record_input_shapes); + if (self.symbolic) { + self.lazy_eval = + std::make_shared(self.no_exec); + self.options_visitor(py::cast(&self.lazy_eval->options())); + } + } else if (!self.compiled) { // traced but not compiled + using namespace std::placeholders; + self.compiled = std::make_shared( + *self.trace_result, self.record_input_shapes); + self.compiled->set_value_comparator( + std::bind(&Trace::compare_value, this, _1, _2)); + self.options_visitor(py::cast(&self.compiled->options())); + self.compiled->compile(); + } + // register transformations + if (self.compiled) { + if (self.profile) { + auto& current_graph = self.compiled->graph(); + if (self.profiler.first != self.compiled->graph().id()) 
{ + // graph changed + self.profiler = std::make_pair( + current_graph.id(), + std::make_shared(¤t_graph)); + } + } + transformations.register_at(self.compiled); + // start execute because InputCallback depends + self.compiled->execute(); + } else if (self.tracing) { + transformations.register_at(self.tracing); + if (self.lazy_eval) { + transformations.register_at(self.lazy_eval); + } + } else { + mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); + } + } + + void exit() { + auto& self = *this; + if (self.tracing) { + transformations.unregister(self.tracing); + self.trace_result = self.tracing->get_result(); + self.tracing.reset(); + if (self.lazy_eval) { + auto lazy_eval = std::move(self.lazy_eval); + transformations.unregister(lazy_eval); + lazy_eval->check_exception(); + } + } else if (self.compiled) { + transformations.unregister(self.compiled); + self.compiled->wait(); + } else { + mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); + } + } + + VarNodeArray dump( + std::shared_ptr graph, + std::vector> inputs, + std::vector> outputs, + bool prefer_input_names) { + auto& self = *this; + mgb_assert(self.trace_result); + // mark is like "arg_0", "kwarg_xxx", "output_0" ... + std::unordered_map mark2var; + for (size_t i = 0; i < self.trace_result->vars.size(); ++i) { + auto& name = self.trace_result->vars[i].mark; + if (!name.empty()) { + mark2var[name] = i; + } + } + std::vector> input_vars; + std::vector> output_vars; + for (auto&& [input_mark, input_name, input_shape] : inputs) { + mgb_assert(input_shape.ndim, "input shape invalid"); + input_vars.push_back( + {mark2var.at(input_mark), input_name, input_shape}); + } + for (auto&& [output_name, repr] : outputs) { + output_vars.push_back({mark2var.at(output_name), repr}); + } + self.options_visitor(py::cast(&graph->options())); + auto vars = self.trace_result->dump( + *graph, input_vars, output_vars, prefer_input_names); + return vars; + } + }; + + py::class_(m, "Trace") + .def(py::init<>()) + .def_readwrite("record_input_shapes", &Trace::record_input_shapes) + .def_readwrite("array_comparator", &Trace::array_comparator) + .def_readwrite("profile", &Trace::profile) + .def_property_readonly( + "options", + [](Trace& self) { + if (self.compiled) { + return &self.compiled->options(); + } else { + return (ComputingGraph::Options*)nullptr; + } + }) + .def("get_profile", + [](Trace& self) -> py::object { + if (self.profiler.second && self.compiled) { + auto json = self.profiler.second->to_json_full( + self.compiled->graph().current_comp_seq()); + return py::str(json->to_string()); + } else { + return py::none(); + } + }) + .def_readwrite("symbolic", &Trace::symbolic) + .def_readwrite("capture_as_const", &Trace::capture_as_const) + .def_readwrite("no_exec", &Trace::no_exec) + .def_readwrite("options_visitor", &Trace::options_visitor) + .def("enter", &Trace::enter) + .def("exit", &Trace::exit) + .def("dump", &Trace::dump) + .def("begin_excluded_region", + [](Trace& self) { + mgb_assert(bool(self.tracing) ^ bool(self.compiled)); + if (self.tracing) { + transformations.unregister(self.tracing); + } else if (self.compiled) { + transformations.unregister(self.compiled); + } }) - .def("__hash__", - [](SharedHandle& sh) { return reinterpret_cast(sh.get()); }); - - m.def("set_tracing", &set_tracing); - m.def("unset_tracing", &unset_tracing); - m.def("set_allow_higher_order_directive", - [](bool value) { GradKey::allow_higher_order_directive = value; }); - m.def("set_module_tracing", &set_module_tracing); - 
m.def("unset_module_tracing", &unset_module_tracing); - m.def("is_tracing_module", &is_tracing_module); + .def("end_excluded_region", [](Trace& self) { + mgb_assert(bool(self.tracing) ^ bool(self.compiled)); + if (self.tracing) { + transformations.register_at(self.tracing); + } else if (self.compiled) { + transformations.register_at(self.compiled); + } + }); + + m.def("name_tensor", [](std::string name, py::object tensor) { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; + tw->m_tensor->reset(output); + }); + + m.def("is_grad_attached", [](std::vector tensors) -> bool { + SmallVector values; + for (auto&& tensor : tensors) { + values.push_back(tensor.cast().m_tensor->data()); + } + auto outputs = imperative::apply(GetGradKey(), values); + if (outputs[0].is()) { + return true; + } else { + return false; + } + }); + + m.def("get_grad_key", [](std::vector tensors) -> py::object { + SmallVector values; + for (auto&& tensor : tensors) { + values.push_back(tensor.cast().m_tensor->data()); + } + auto outputs = imperative::apply(GetGradKey(), values); + if (auto* grad_key_val = outputs[0].as()) { + return py::reinterpret_borrow( + GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val))); + } else { + return py::none(); + } + }); + + m.def("set_grad", [](py::object py_key, py::function backward_fn, + std::vector inputs, + std::vector outputs) { + mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); + auto* key = reinterpret_cast(py_key.ptr())->inst(); + GenericFunction generic_backward_fn = + [backward_fn](Span output_grads) -> std::vector { + py::list output_grad_tws; + for (auto&& output_grad : output_grads) { + if (output_grad) { + output_grad_tws.append( + TensorWrapper::make(py_tensor_type, output_grad)); + } else { + output_grad_tws.append(py::none()); + } + } + py::tuple input_grad_tws = backward_fn(*output_grad_tws); + std::vector input_grads; + for (auto&& input_grad_tw : input_grad_tws) { + if (!input_grad_tw.is_none()) { + input_grads.push_back( + py::cast(input_grad_tw).m_tensor->data()); + } else { + input_grads.push_back({}); + } + } + return input_grads; + }; + SmallVector values; + for (auto&& input : inputs) { + values.push_back(input.cast().m_tensor->data()); + } + for (auto&& output : outputs) { + values.push_back(output.cast().m_tensor->data()); + } + auto wrapped_output_values = imperative::apply( + SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); + std::vector wrapped_outputs; + mgb_assert(wrapped_output_values.size() == outputs.size()); + for (auto&& output_value : wrapped_output_values) { + wrapped_outputs.push_back( + TensorWrapper::make(py_tensor_type, output_value)); + } + return wrapped_outputs; + }); + + static py::function module_trace_hook; + + static std::shared_ptr module_trace_transformation; + static int module_tracing = 0; + + m.def("set_module_tracing", [=] { + if (!module_trace_transformation) { + mgb_assert(module_trace_hook); + module_trace_transformation = + std::make_shared(module_trace_hook); + } + if (++module_tracing == 1) { + transformations.register_at( + module_trace_transformation); + } + }); + + m.def("unset_module_tracing", [=] { + if (--module_tracing == 0) { + transformations.unregister( + module_trace_transformation); + } + }); + + m.def("is_tracing_module", [=] { return module_tracing > 0; }); + + m.def("set_module_trace_hook", + [](py::function function) { module_trace_hook = function; }); + + 
m.def("begin_record_values", [] { Value::begin_record_values(); }); + + m.def("end_record_values", [] { + std::vector> reprs; + auto values = Value::end_record_values(); + for (auto&& value : values) { + reprs.push_back({value.id(), value.to_string()}); + } + return reprs; + }); + + py::register_exception(m, "TraceError"); } #undef MGE_PY_INTERFACE diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 7caf8fb1..c438aea6 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -20,6 +20,8 @@ #include "pybind11/pybind11.h" #include "./pyext17.h" +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/utils/span.h" namespace mgb::imperative::python { @@ -32,126 +34,67 @@ struct ObjectPtr : B { } // namespace mgb::imperative::python -#include "./grad_info.h" // for struct GradInfo -#include "./trace_info.h" // for struct TraceInfo - namespace mgb::imperative::python { -struct GradKey; - extern interpreter::Interpreter::Channel* interpreter_for_py; +extern PyTypeObject* py_tensor_type; -class SharedHandle { - using Handle = interpreter::Interpreter::Handle; - static_assert(std::is_pointer_v); - std::shared_ptr> holder; - -public: - inline explicit SharedHandle(Handle handle) - : holder(handle, [](auto* h) { - if (h) { - interpreter_for_py->del(h); - } - }) {} - SharedHandle(const SharedHandle&) = default; - SharedHandle& operator=(const SharedHandle&) = default; - SharedHandle(SharedHandle&&) = default; - SharedHandle& operator=(SharedHandle&&) = default; - - inline Handle get() { return holder.get(); } -}; - -// impl in grad.cpp -class GradInfoCollection { +struct Tensor : std::enable_shared_from_this, NonCopyableObj { private: - SmallVector m_storage; - -protected: - void _shrink(); + std::string m_name; + ValueRef m_data; public: - bool contains(GradKey* key); - GradInfo& operator[](GradKey* key); - GradInfo& at(GradKey* key); - bool empty() { - _shrink(); - return m_storage.empty(); - } - auto begin() { - _shrink(); - return m_storage.begin(); - } - auto end() { - _shrink(); - return m_storage.end(); - } - size_t count(GradKey* key) { return contains(key) ? 
1 : 0; } -}; - -struct Tensor : std::enable_shared_from_this, NonCopyableObj { - using flags_t = uint64_t; - - struct Flags { - static constexpr flags_t SCALAR = 1; - static constexpr flags_t GRAD = 1 << 1; - static constexpr flags_t TRACE = 1 << 2; - static constexpr flags_t MODULE_TRACE = 1 << 3; - }; - - flags_t m_flags = 0; - - GradInfoCollection m_grad_info_dict; - TraceInfo m_trace_info; - SharedHandle m_handle; - std::string user_custom_name; - std::string automatic_name; - cg::VarNode* m_var; - pybind11::object m_module_trace_info; - using Handle = interpreter::Interpreter::Handle; - inline Tensor() : m_handle(nullptr), m_var(nullptr) {} - inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} - inline explicit Tensor(SharedHandle handle) - : m_handle(std::move(handle)), m_var(nullptr) {} - inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {} + inline explicit Tensor(ValueRef data) : m_data{data} {} ~Tensor() = default; inline std::shared_ptr copy() { - auto ret = std::make_shared(m_handle); - ret->m_flags = m_flags; - ret->m_grad_info_dict = m_grad_info_dict; - ret->m_trace_info = m_trace_info; - ret->m_var = m_var; + auto ret = std::make_shared(m_data.unwrap()); + ret->m_name = m_name; return ret; } - inline DType dtype() { - if (m_var) { - return m_var->dtype(); + inline DType dtype() { return *data().dtype(); } + inline CompNode comp_node() { return *data().device(); } + inline std::optional shape() { + auto shape = data().shape(); + if (!shape) { + return {}; } - return interpreter_for_py->get_dtype(m_handle.get()); + return *shape; } - inline CompNode comp_node() { - if (m_var) { - return m_var->comp_node(); + inline HostValue::ref_t numpy() { return data().numpy(); } + inline void reset(ValueRef value) { + m_data = value; + if (!m_name.empty()) { + set_name(m_name); } - return interpreter_for_py->get_device(m_handle.get()); } - inline TensorShape shape() { - if (m_var) { - return m_var->shape(); + inline ValueRef data() { return m_data.unwrap(); } + bool is_scalar() { return data().is_scalar(); } + inline std::string name() { return m_name; } + inline void set_name(std::string name) { + m_name = name; + if (!name.empty()) { + auto output = imperative::apply(RenameValue(name), m_data)[0]; + m_data = output; } - return interpreter_for_py->get_shape(m_handle.get()); } }; struct TensorWrapper { +public: std::shared_ptr m_tensor; inline TensorWrapper(std::shared_ptr tensor = {}) - : m_tensor(std::move(tensor)) {} + : m_tensor(std::move(tensor)) { + mgb_assert(tensor, "empty storage"); + } + + inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared(value)) {} TensorWrapper(PyObject* args, PyObject* kwargs); ~TensorWrapper() = default; @@ -191,33 +134,17 @@ struct TensorWrapper { void reset(PyObject*); PyObject* detach(); PyObject* isscalar(); - void setscalar(); - void unsetscalar(); PyObject* _dev_tensor(); void _drop(); PyObject* varnode(); - void reset_varnode(); - PyObject* handle(); - void set_handle(PyObject*); - - PyObject* mixin_handle(); PyObject* recording(); PyObject* copied(); - - void set_mixin_handle(PyObject*); - void set_recording(PyObject*); - - PyObject* compiled_info(); - void set_compiled_info(PyObject*); - PyObject* trace_mixin_info(); - void set_trace_mixin_info(PyObject*); PyObject* module_trace_info(); void set_module_trace_info(PyObject*); - PyObject* user_custom_name(); - void set_user_custom_name(PyObject*); - PyObject* automatic_name(); - void set_automatic_name(PyObject*); + void 
_set_name(PyObject*); PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; + PyObject* _detail(); + void _watch(); }; struct PySymbolVar { @@ -230,113 +157,8 @@ struct PySymbolVar { PyObject* py_apply( PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); -struct ApplyContext { - static Tensor::flags_t global_disable; - static Tensor::flags_t global_enable; - - Tensor::flags_t flags = 0; - std::shared_ptr op; - Tensor* const* args; - size_t nargs; - PyTypeObject* pytype = nullptr; - bool backward = false; - - class scoped_disable : NonCopyableObj { - Tensor::flags_t saved_flags; - - public: - scoped_disable(Tensor::flags_t flags) - : saved_flags(ApplyContext::global_disable) { - ApplyContext::global_disable |= flags; - } - ~scoped_disable() { ApplyContext::global_disable = saved_flags; } - }; -}; - -using apply_result_t = SmallVector, 8>; - -apply_result_t apply(ApplyContext& ctx); - -template -decltype(auto) resolve_arrow(T&& p) { - if constexpr (std::is_pointer_v>) { - auto* ret = p; - return ret; - } else { - auto probe = [](auto&& p) -> decltype(p.operator->()) {}; - if constexpr (std::is_invocable_v) { - return resolve_arrow(p.operator->()); - } else { - return std::forward(p); - } - } -} - -template -constexpr bool is_all_tensor_ptr = - (... && std::is_same_v())), Tensor*>); - -template , int> = 0> -apply_result_t apply(std::shared_ptr op, Args&&... args) { - ApplyContext ctx; - Tensor* arg_arr[] = {resolve_arrow(args)...}; - ctx.flags = (0 | ... | args->m_flags); - ctx.args = arg_arr; - ctx.nargs = sizeof...(args); - ctx.op = std::move(op); - return apply(ctx); -} - -inline auto apply(std::shared_ptr op, Tensor* const* args, size_t nargs) { - ApplyContext ctx; - ctx.op = std::move(op); - ctx.nargs = nargs; - ctx.args = args; - for (size_t i = 0; i < nargs; ++i) { - ctx.flags |= args[i]->m_flags; - } - return apply(ctx); -} - -template -auto apply(std::shared_ptr op, T&& tensors) -> std::enable_if_t< - std::is_same_v, apply_result_t> { - size_t nargs = tensors.size(); - Tensor* args[nargs]; - for (size_t i = 0; i < nargs; ++i) { - args[i] = resolve_arrow(tensors[i]); - } - return apply(op, args, nargs); -} - -std::shared_ptr make_const(imperative::TensorPtr value); - -inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) { - SmallVector> inputs; - for (size_t i = 0; i < nargs; ++i) { - inputs.push_back(args[i]->shared_from_this()); - } - auto apply_functor = [](std::shared_ptr op, - SmallVector> inputs, - size_t) { return apply(op, std::move(inputs)); }; - return graph.apply(inputs, apply_functor, &make_const); -} - -template -auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t< - std::is_same_v, Tensor*>, apply_result_t> { - size_t nargs = tensors.size(); - Tensor* args[nargs]; - for (size_t i = 0; i < nargs; ++i) { - args[i] = resolve_arrow(tensors[i]); - } - return apply(graph, args, nargs); -} - void init_tensor(pybind11::module); -extern PyObject* cpp_apply_with_tracing; -extern PyObject* cpp_apply_backward_varnode; extern PyObject* cpp_apply_module_trace; } // namespace mgb::imperative::python diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp deleted file mode 100644 index 8a3d9696..00000000 --- a/imperative/python/src/trace.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/** - * \file imperative/python/src/trace.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
-#include "./trace.h"
-#include "./helper.h"
-#include "megbrain/imperative/ops/autogen.h"
-
-namespace py = pybind11;
-
-namespace mgb::imperative::python {
-
-apply_result_t apply_trace(ApplyContext& ctx) {
-    apply_result_t outputs;
-
-    if (ctx.backward) {
-        // reach here when compiled=True
-        auto args = py::tuple(ctx.nargs + 1);
-        args[0] = py::cast(ctx.op);
-        for (size_t i = 0; i < ctx.nargs; i++) {
-            args[i + 1] = py::cast(ctx.args[i]->m_var);
-        }
-        py::object pyout = py::reinterpret_steal<py::object>(
-                PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
-        if (!pyout)
-            throw py::error_already_set();
-
-        // assumption: python function always returns PyList
-        auto tup = py::reinterpret_borrow<py::list>(pyout);
-        for (size_t i = 0; i < tup.size(); i++) {
-            auto pitem = tup[i].cast<cg::VarNode*>();
-            outputs.emplace_back(std::make_shared<Tensor>(pitem));
-        }
-        return outputs;
-    }
-
-    auto args = py::tuple(ctx.nargs + 1);
-    args[0] = py::cast(ctx.op);
-    for (size_t i = 0; i < ctx.nargs; i++) {
-        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
-    }
-    auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr);
-    if (!pyout)
-        throw py::error_already_set();
-
-    // assumption: python function always returns PyList
-    auto tup = py::reinterpret_steal<py::list>(pyout);
-    for (size_t i = 0; i < tup.size(); i++) {
-        auto tw = TensorWrapper::try_cast(tup[i].ptr());
-        outputs.emplace_back(tw->m_tensor);
-    }
-    return outputs;
-}
-
-} // namespace mgb::imperative::python
diff --git a/imperative/python/src/trace.h b/imperative/python/src/trace.h
deleted file mode 100644
index bd79ad02..00000000
--- a/imperative/python/src/trace.h
+++ /dev/null
@@ -1,28 +0,0 @@
-/**
- * \file imperative/python/src/trace.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
-#include <exception>
-#include "./tensor.h"
-
-namespace mgb::imperative::python {
-
-class TraceReadError : public std::exception {
-public:
-    explicit TraceReadError(const char* m) : message{m} {}
-    const char* what() const noexcept override { return message.c_str(); }
-
-private:
-    std::string message = "";
-};
-
-apply_result_t apply_trace(ApplyContext& ctx);
-
-} // namespace mgb::imperative::python
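With the Python-callback tracer (trace.cpp and trace_info.h) deleted, the same user-facing behavior now comes from the TracingTransformation / CompiledTransformation pair registered by the Trace proxy shown earlier. A quick sanity sketch against the public API only; the function body and tolerances are illustrative.

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.jit import trace

@trace(symbolic=True)
def step(x):
    return F.exp(x) + 1

x = mge.tensor(np.random.rand(4).astype("float32"))
y_first = step(x)   # first call records the graph (TracingTransformation)
y_later = step(x)   # later calls replay the compiled graph (CompiledTransformation)
np.testing.assert_allclose(y_first.numpy(), y_later.numpy(), rtol=1e-6)

diff --git a/imperative/python/src/trace_info.h b/imperative/python/src/trace_info.h
deleted file mode 100644
index f6a65615..00000000
--- a/imperative/python/src/trace_info.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * \file imperative/python/src/trace_info.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.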
- */
-
-#include "Python.h"
-#include "inttypes.h"
-
-namespace mgb::imperative::python {
-
-struct TraceInfo {
-    int64_t mixin_handle = -1;
-    bool recording = false;
-
-    // refer to CompiledTensorProxy in tracing.py, works from second trace step
-    PyObject* compiled_info = nullptr;
-    // refer to TensorInfo in tracing.py, only works in first trace step
-    PyObject* trace_mixin_info = nullptr;
-
-    TraceInfo() = default;
-
-    TraceInfo& operator=(const TraceInfo& that) {
-        mixin_handle = that.mixin_handle;
-        recording = that.recording;
-
-        trace_mixin_info = that.trace_mixin_info;
-        Py_XINCREF(trace_mixin_info);
-        compiled_info = that.compiled_info;
-        Py_XINCREF(compiled_info);
-
-        return *this;
-    }
-
-    ~TraceInfo() {
-        Py_XDECREF(trace_mixin_info);
-        Py_XDECREF(compiled_info);
-    }
-
-private:
-    TraceInfo(const TraceInfo& that) = default;
-};
-
-} // namespace mgb::imperative::python
diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py
index 518d0584..4424d5b3 100644
--- a/imperative/python/test/conftest.py
+++ b/imperative/python/test/conftest.py
@@ -16,10 +16,6 @@ import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
 from megengine.device import get_device_count
-from megengine.experimental.autograd import (
-    disable_higher_order_directive,
-    enable_higher_order_directive,
-)
 from megengine.jit import trace as _trace
 from megengine.module import Linear, Module
 
@@ -45,13 +41,3 @@ def skip_distributed(request):
             platform.system()
         )
     )
-
-
-@pytest.fixture(autouse=True)
-def resolve_require_higher_order_directive(request):
-    marker = request.node.get_closest_marker("require_higher_order_directive")
-    if marker:
-        enable_higher_order_directive()
-    yield
-    if marker:
-        disable_higher_order_directive()
diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
index d7deeab8..f822e518 100644
--- a/imperative/python/test/integration/test_trace_dump.py
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -146,5 +146,5 @@ def test_dump_bn_train_mode():
     data = mge.tensor(np.random.random((10, 10, 10, 10)))
     bn_train(data)
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(RuntimeError):
         bn_train.dump("test.mge")
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index a288ba26..eb1f5a4c 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -17,7 +17,7 @@ import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
-from megengine.autodiff import GradManager
+from megengine.autodiff import Function, GradManager
 from megengine.jit import trace
 
 
@@ -214,7 +214,7 @@ def test_remote_grad(trace_mode):
             x = dist.functional.remote_recv(rank - 1)
             y = m(x)
             if rank != size - 1:
-                dist.functional.remote_send(y, dest_rank=rank + 1)
+                x = dist.functional.remote_send(y, dest_rank=rank + 1)
                 gm.backward()
             else:
                 y = y.mean()
@@ -224,7 +224,7 @@ def test_remote_grad(trace_mode):
         if trace_mode is not None:
             train_func = trace(symbolic=trace_mode)(train_func)
 
-        for i in range(3):
+        for i in range(1):
             train_func(x)
 
     worker()
@@ -340,7 +340,6 @@ def test_broadcast_grad(trace_mode):
     worker()
 
 
-@pytest.mark.require_higher_order_directive()
 def test_2nd_grad_with_manager():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
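For reference, the nested-GradManager pattern these tests exercise, reduced to a minimal sketch. It uses the same public API as the tests; the concrete function and the final assert are illustrative.

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
gm = GradManager().attach([x])    # outer manager: second-order gradient
gm2 = GradManager().attach([x])   # inner manager: first-order gradient

with gm:
    with gm2:
        y = F.cos(x)
        gm2.backward(y)           # x.grad = -sin(x)
    dx = x.grad
    x.grad = None
    gm.backward(dx)               # differentiate the gradient: x.grad = -cos(x)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)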
@@ -359,7 +358,6 @@ def test_2nd_grad_with_manager(): ) -@pytest.mark.require_higher_order_directive() def test_grad_manager_group(): x_np = np.random.rand(10).astype("float32") x = mge.tensor(x_np) @@ -376,7 +374,6 @@ def test_grad_manager_group(): x.grad = None -@pytest.mark.require_higher_order_directive() def test_grad_manager_group_visibility(): x_np = np.random.rand(10).astype("float32") x = mge.tensor(x_np) @@ -392,7 +389,6 @@ def test_grad_manager_group_visibility(): np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) -@pytest.mark.require_higher_order_directive() def test_grad_manager_visibility_by_order(): x_np = np.random.rand(10).astype("float32") x = mge.tensor(x_np) @@ -410,7 +406,6 @@ def test_grad_manager_visibility_by_order(): np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) -@pytest.mark.require_higher_order_directive() @pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) def test_emulate_forward_mode_with_reverse_mode(target): def jvp(inp, expr): @@ -434,3 +429,43 @@ def test_emulate_forward_mode_with_reverse_mode(target): np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) + + +def test_2nd_grad_with_custom_gradient(): + class MySin(Function): + def forward(self, x): + self.inp = x + x = mge.Tensor(x.numpy()) + y = F.sin(x) + return y + + def backward(self, dy): + dx = F.cos(self.inp) * dy + return dx + + class MyCos(Function): + def forward(self, x): + self.inp = x + x = mge.Tensor(x.numpy()) + y = F.cos(x) + return y + + def backward(self, dy): + dx = -MySin()(self.inp) * dy + return dx + + x_np = np.random.rand(10).astype("float32") + x = mge.tensor(x_np) + + gm = GradManager().attach([x]) + gm2 = GradManager().attach([x]) + + with gm: + with gm2: + y = MyCos()(x) + gm2.backward(y) + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + gm.backward(x.grad) + np.testing.assert_almost_equal( + x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5 + ) diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 192038b3..355667af 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -7,8 +7,6 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
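
[Editor's note] The new `test_2nd_grad_with_custom_gradient` above works because a `Function.backward` written in terms of differentiable ops (or other `Function`s, as `MyCos.backward` does with `MySin`) can itself be recorded by an enclosing grad manager. A compact sketch of that contract — `MyNegSin` is a hypothetical op invented here to mirror `MySin`/`MyCos`, not part of the codebase:

    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import Function

    class MyNegSin(Function):
        def forward(self, x):
            self.inp = x                # keep the recorded input for backward
            x = mge.Tensor(x.numpy())   # detached copy, as in the test above
            return -F.sin(x)

        def backward(self, dy):
            # built from differentiable ops, so an outer GradManager can
            # record this expression and differentiate the gradient again
            return -F.cos(self.inp) * dy

Note that the tests instantiate a fresh `MyCos()` / `MySin()` for every application, so the sketch would likewise be called as `y = MyNegSin()(x)` once per forward pass.
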
import gc -import platform -import weakref import numpy as np import pytest @@ -60,24 +58,20 @@ def test_dist_grad(): def worker(): rank = dist.get_rank() if rank == 0: - grad = Grad() - - x = as_tensor(x_np) - grad.wrt(x, callback=save_to(x)) - # need a placeholder to trace operator - remote_send(x, 1) - recv_x = remote_recv(1) - y = recv_x * recv_x - - grad([y], [as_tensor(np.ones_like(x_np))]) + with Grad() as grad: + x = as_tensor(x_np) + grad.wrt(x, callback=save_to(x)) + # need a placeholder to trace operator + remote_send(x, 1) + recv_x = remote_recv(1) + y = recv_x * recv_x + grad([y], [as_tensor(np.ones_like(x_np))]) np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2) elif rank == 1: - grad = Grad() - - recv_x = remote_recv(0) - remote_send(recv_x, 0) - - grad([], []) + with Grad() as grad: + recv_x = remote_recv(0) + remote_send(recv_x, 0) + grad([], []) worker() @@ -86,11 +80,11 @@ def test_grad(): x_np = np.random.rand(10).astype("float32") x = as_tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - - y = cos(x) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = cos(x) + grad(y, as_tensor(np.ones_like(x_np))) - grad(y, as_tensor(np.ones_like(x_np))) np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np)) @@ -98,12 +92,12 @@ def test_grad_2(): x_np = np.random.rand(10).astype("float32") x = as_tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = mul(x, x) + y = mul(y, y) + grad(y, as_tensor(np.ones_like(x_np))) - y = mul(x, x) - y = mul(y, y) - - grad(y, as_tensor(np.ones_like(x_np))) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) @@ -113,32 +107,31 @@ def test_2nd_grad(): x = as_tensor(x_np) ones = as_tensor(np.ones_like(x_np)) - grad = Grad().wrt(x, callback=save_to(x)) - grad._priority = -1 - grad2 = Grad().wrt(x, callback=save_to(x)) - grad2._priority = 0 - - y = cos(x) + with Grad("grad2") as grad2: + with Grad("grad") as grad: + grad2.wrt(x, callback=save_to(x)) + grad.wrt(x, callback=save_to(x)) + y = cos(x) + grad(y, ones) + z = x.grad + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) - grad(y, ones) - z = x.grad - np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + x.grad = None + grad2(z, ones) - x.grad = None - grad2(z, ones) - np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) + np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) def test_grad_with_tensor_wrapper(): x_np = np.random.rand(10).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = mul(x, x) + y = mul(y, y) + grad(y, mge.Tensor(np.ones_like(x_np))) - y = mul(x, x) - y = mul(y, y) - - grad(y, mge.Tensor(np.ones_like(x_np))) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) @@ -162,18 +155,21 @@ def test_release(): @check def _(): - g = Grad().wrt(x) - y = x * x - g(y, dy) + with Grad() as g: + g.wrt(x) + y = x * x + g(y, dy) @check def _(): - with Grad().wrt(x): + with Grad() as g: + g.wrt(x) pass @check def _(): - with Grad().wrt(x): + with Grad() as g: + g.wrt(x) y = x * x @@ -181,12 +177,12 @@ def test_grad_inplace(): x_np = np.random.rand(10).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = mul(x, x) + y *= y + grad(y, mge.Tensor(np.ones_like(x_np))) - y = mul(x, x) - y *= y 
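
[Editor's note] The `test_autodiff.py` rewrites in this file all follow one mechanical rule: `wrt`, the forward expressions, and the backward call move inside the `with Grad() as grad:` block, because the recorder only exists between `__enter__` and `__exit__`. The migrated shape, sketched — the `save_to` helper is an assumption here, modeled on this suite's usage with the callback receiving just the computed gradient:

    import numpy as np
    import megengine as mge
    from megengine.core.autodiff.grad import Grad

    def save_to(t):
        def callback(grad):
            t.grad = grad   # stash the incoming gradient on the tensor
        return callback

    x = mge.Tensor(np.random.rand(10).astype("float32"))
    with Grad() as grad:
        grad.wrt(x, callback=save_to(x))              # register inside the block
        y = x * x
        grad(y, mge.Tensor(np.ones_like(x.numpy())))  # backward, still inside
    # x.grad now holds 2 * x; calling grad(...) after the block would fail,
    # since the underlying recorder is torn down on exit
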
- - grad(y, mge.Tensor(np.ones_like(x_np))) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) @@ -196,11 +192,11 @@ def test_identity(): dy_np = np.random.rand(*x.shape).astype("float32") dy = mge.Tensor(dy_np) - grad = Grad().wrt(x, callback=save_to(x)) - - (y,) = apply(Identity(), x) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + (y,) = apply(Identity(), x) + grad(y, dy) - grad(y, dy) np.testing.assert_array_equal(x.grad.numpy(), dy_np) @@ -220,15 +216,14 @@ def test_elemwise_add(): refs["y"] = TensorWeakRef(y) return x + y - grad = Grad().wrt(x, callback=save_to(x)) - - z = f(x, y) - del y + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + z = f(x, y) + del y + for k, r in refs.items(): + assert r() is None + grad(z, dz) - for k, r in refs.items(): - assert r() is None - - grad(z, dz) np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5) @@ -245,13 +240,12 @@ def test_elemwise_relu(): refs["x"] = TensorWeakRef(x) return relu(x) - grad = Grad().wrt(x, callback=save_to(x)) - - z = f(x) - - assert refs["x"]() is None + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + z = f(x) + assert refs["x"]() is None + grad(z, dz) - grad(z, dz) np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0]) @@ -269,21 +263,21 @@ def test_reshape(): x_np = np.random.rand(2, 5).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + refs = {} - refs = {} + def f(x): + x = x * 1 + y = x.reshape(5, 2) + refs["x"] = TensorWeakRef(x) + return y - def f(x): - x = x * 1 - y = x.reshape(5, 2) - refs["x"] = TensorWeakRef(x) - return y + y = f(x) + for _, r in refs.items(): + assert r() is None + grad(y, F.ones_like(y)) - y = f(x) - for _, r in refs.items(): - assert r() is None - - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) @@ -291,21 +285,21 @@ def test_subtensor(): x_np = np.random.rand(3, 3).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - - refs = {} + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + refs = {} - def f(x): - x = x * 1 - y = x[1:-1, :2] - refs["x"] = TensorWeakRef(x) - return y + def f(x): + x = x * 1 + y = x[1:-1, :2] + refs["x"] = TensorWeakRef(x) + return y - y = f(x) - for _, r in refs.items(): - assert r() is None + y = f(x) + for _, r in refs.items(): + assert r() is None + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal( np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy() ) @@ -315,21 +309,21 @@ def test_IndexingMultiAxisVec(): x_np = np.random.rand(3, 3).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + refs = {} - refs = {} + def f(x): + x = x * 1 + y = x[[0, 2], [0, 2]] + refs["x"] = TensorWeakRef(x) + return y - def f(x): - x = x * 1 - y = x[[0, 2], [0, 2]] - refs["x"] = TensorWeakRef(x) - return y + y = f(x) + for _, r in refs.items(): + assert r() is None + grad(y, F.ones_like(y)) - y = f(x) - for _, r in refs.items(): - assert r() is None - - grad(y, F.ones_like(y)) np.testing.assert_equal( np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() ) @@ -339,21 +333,21 @@ def test_AxisAddRemove(): x_np = np.random.rand(1, 5).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - - refs = {} + with Grad() as grad: + grad.wrt(x, 
callback=save_to(x)) + refs = {} - def f(x): - x = x * 1 - y = F.squeeze(F.expand_dims(x, 2), 0) - refs["x"] = TensorWeakRef(x) - return y + def f(x): + x = x * 1 + y = F.squeeze(F.expand_dims(x, 2), 0) + refs["x"] = TensorWeakRef(x) + return y - y = f(x) - for _, r in refs.items(): - assert r() is None + y = f(x) + for _, r in refs.items(): + assert r() is None + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal( np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy() ) @@ -363,10 +357,11 @@ def test_Broadcast(): x_np = np.random.rand(3, 3, 1).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - y = F.broadcast_to(x, (3, 3, 10)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = F.broadcast_to(x, (3, 3, 10)) + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) @@ -374,10 +369,11 @@ def test_interpolate_fastpath(): x_np = np.random.rand(3, 3, 32, 32).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) @@ -385,10 +381,11 @@ def test_Reduce_sum(): x_np = np.random.rand(3, 3).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - y = x.sum(axis=0) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = x.sum(axis=0) + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) @@ -396,10 +393,11 @@ def test_Reduce_mean(): x_np = np.random.rand(3, 3).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) - y = x.mean(axis=0) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + y = x.mean(axis=0) + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy()) @@ -407,21 +405,21 @@ def test_addAxis(): x_np = np.random.rand(3, 3).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + refs = {} - refs = {} - - def f(x): - x = x * 1 - y = F.expand_dims(x, [2, 3]) - refs["x"] = TensorWeakRef(x) - return y + def f(x): + x = x * 1 + y = F.expand_dims(x, [2, 3]) + refs["x"] = TensorWeakRef(x) + return y - y = f(x) - for _, r in refs.items(): - assert r() is None + y = f(x) + for _, r in refs.items(): + assert r() is None + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) @@ -429,21 +427,21 @@ def test_removeAxis(): x_np = np.random.rand(3, 3, 1, 1).astype("float32") x = mge.Tensor(x_np) - grad = Grad().wrt(x, callback=save_to(x)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + refs = {} - refs = {} + def f(x): + x = x * 1 + y = F.squeeze(x, [2, 3]) + refs["x"] = TensorWeakRef(x) + return y - def f(x): - x = x * 1 - y = F.squeeze(x, [2, 3]) - refs["x"] = TensorWeakRef(x) - return y - - y = f(x) - for _, r in refs.items(): - assert r() is None + y = f(x) + for _, r in refs.items(): + assert r() is None + grad(y, F.ones_like(y)) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) @@ 
-452,11 +450,14 @@ def test_dot(): x = mge.Tensor(x) u = F.ones((2,)) v = F.ones((2,)) - grad = Grad().wrt(x, callback=save_to(x)) - def f(x): - return F.dot(u, F.matmul(x, v)) + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + + def f(x): + return F.dot(u, F.matmul(x, v)) + + y = f(x) + grad(y, F.ones_like(y)) - y = f(x) - grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 3176c057..2c9968ac 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -267,25 +267,27 @@ def _gen_roi_inp(): def test_roi_align(): inp_feat, rois = _gen_roi_inp() - grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) - - output_shape = (7, 7) - out_feat = F.vision.roi_align( - inp_feat, - rois, - output_shape=output_shape, - mode="average", - spatial_scale=1.0 / 4, - sample_points=2, - aligned=True, - ) - assert make_shape_tuple(out_feat.shape) == ( - rois.shape[0], - inp_feat.shape[1], - *output_shape, - ) + with Grad() as grad: + grad.wrt(inp_feat, callback=_save_to(inp_feat)) + + output_shape = (7, 7) + out_feat = F.vision.roi_align( + inp_feat, + rois, + output_shape=output_shape, + mode="average", + spatial_scale=1.0 / 4, + sample_points=2, + aligned=True, + ) + assert make_shape_tuple(out_feat.shape) == ( + rois.shape[0], + inp_feat.shape[1], + *output_shape, + ) + + grad(out_feat, tensor(F.ones_like(out_feat))) - grad(out_feat, tensor(F.ones_like(out_feat))) assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) @@ -307,20 +309,23 @@ def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): def test_correlation(): ##test case 0 check the grad shape data1, data2 = _gen_correlation() - grad = Grad().wrt(data1, callback=_save_to(data1)) - out_feat = F.vision.correlation( - data1, - data2, - kernel_size=5, - max_displacement=4, - stride1=2, - stride2=2, - pad_size=2, - is_multiply=True, - ) + with Grad() as grad: + grad.wrt(data1, callback=_save_to(data1)) + + out_feat = F.vision.correlation( + data1, + data2, + kernel_size=5, + max_displacement=4, + stride1=2, + stride2=2, + pad_size=2, + is_multiply=True, + ) + + grad(out_feat, tensor(F.ones_like(out_feat))) - grad(out_feat, tensor(F.ones_like(out_feat))) assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194 @@ -391,32 +396,36 @@ def test_correlation(): def test_roi_pooling(): inp_feat, rois = _gen_roi_inp() - grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) - output_shape = (7, 7) - out_feat = F.vision.roi_pooling( - inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, - ) - assert make_shape_tuple(out_feat.shape) == ( - rois.shape[0], - inp_feat.shape[1], - *output_shape, - ) + with Grad() as grad: + grad.wrt(inp_feat, callback=_save_to(inp_feat)) + output_shape = (7, 7) + out_feat = F.vision.roi_pooling( + inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, + ) + assert make_shape_tuple(out_feat.shape) == ( + rois.shape[0], + inp_feat.shape[1], + *output_shape, + ) + + grad(out_feat, tensor(F.ones_like(out_feat))) - grad(out_feat, tensor(F.ones_like(out_feat))) assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) def test_adaptive_avg_pool2d(): inp = tensor(np.arange(0, 16, 
dtype=np.float32).reshape(1, 1, 4, 4)) oshp = (2, 2) - grad = Grad().wrt(inp, callback=_save_to(inp)) - outp = F.adaptive_avg_pool2d(inp, oshp,) - assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) - np.testing.assert_equal( - outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) - ) + with Grad() as grad: + grad.wrt(inp, callback=_save_to(inp)) + outp = F.adaptive_avg_pool2d(inp, oshp,) + assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) + np.testing.assert_equal( + outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) + ) + + grad(outp, tensor(F.ones_like(outp))) - grad(outp, tensor(F.ones_like(outp))) assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) np.testing.assert_equal( inp.grad.numpy(), @@ -439,14 +448,16 @@ def test_adaptive_avg_pool2d(): def test_adaptive_max_pool2d(): inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) oshp = (2, 2) - grad = Grad().wrt(inp, callback=_save_to(inp)) - outp = F.adaptive_max_pool2d(inp, oshp,) - assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) - np.testing.assert_equal( - outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) - ) + with Grad() as grad: + grad.wrt(inp, callback=_save_to(inp)) + outp = F.adaptive_max_pool2d(inp, oshp,) + assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) + np.testing.assert_equal( + outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) + ) + + grad(outp, tensor(F.ones_like(outp))) - grad(outp, tensor(F.ones_like(outp))) assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) np.testing.assert_equal( inp.grad.numpy(), diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index d45944f4..bc764537 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -351,7 +351,7 @@ def test_expand_dims_for_scalar(): for axis in [1, -2, (1, 2), (-2, -3)]: np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) - np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis) + np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis) @pytest.mark.parametrize("is_varnode", [True, False]) diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index de2f2da9..2854e653 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -9,6 +9,7 @@ import inspect import io import itertools +import random from tempfile import mkstemp import numpy as np @@ -25,7 +26,7 @@ from megengine.core.ops import builtin as ops from megengine.core.ops.builtin import Elemwise from megengine.core.tensor.utils import isscalar from megengine.functional import exp, log -from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace +from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace from megengine.module import Module from megengine.random import normal, uniform from megengine.utils.naming import AutoNaming @@ -464,36 +465,92 @@ def test_trace_warp_perspective(): f(x, M) -def test_raise_on_trace(): - step_count = 0 - catch_count = 0 - bad_step = 10 +@pytest.mark.parametrize( + "normal_expr, mismatch_expr, reason", + [ + ("a + b + c", "a + b - c", "operator mismatch"), + ("a + b + 1", "a + b + 2", "tensors not equals"), + ("((a + b), 
(b + c))[0]", "a + b", "mismature end"), + ("a + b + c", "c + (a + b)", "expect internal node, got external"), + ("c + (a + b)", "a + b + c", "expect external node, got internal"), + ("a + b + c", "a + b + c + c", "too many instructions"), + ("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"), + ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"), + ], +) +def test_trace_mismatch(normal_expr, mismatch_expr, reason): + a = tensor([1, 2, 3, 4]) + b = tensor([5, 6, 7, 8]) + c = tensor([9, 0, 1, 2]) + + mismatch = False + + @trace(symbolic=True) + def fn(a, b, c): + if not mismatch: + result = eval(normal_expr) + else: + result = eval(mismatch_expr) + return result + + for i in range(20): + try: + d = fn(a, b, c) + except TraceError as e: + assert mismatch + assert str(e) == "trace error because {}".format(reason) + except: + pytest.fail("unexpected trace error") + else: + assert not mismatch + np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy()) + mismatch = random.random() > 0.8 - class CatchMe(Exception): - pass +def test_exception_in_trace(): a = tensor([1, 2, 3, 4]) b = tensor([5, 6, 7, 8]) c = tensor([9, 0, 1, 2]) - @trace - def add_abc(a, b, c): - ps = a + b - result = ps + c - if step_count == bad_step: - raise CatchMe("catch me") + mismatch = False + + exc = Exception() + + @trace(symbolic=True) + def fn(a, b, c): + result = a + b + if not mismatch: + result += c + else: + raise exc return result - for i in range(100): + for i in range(20): try: - d = add_abc(a, b, c) - except CatchMe as e: - catch_count += 1 + d = fn(a, b, c) + except TraceError as e: + pytest.fail("unexpected trace error") + except Exception as e: + assert mismatch + assert e is exc else: + assert not mismatch np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) - step_count += 1 + mismatch = random.random() > 0.8 - assert catch_count == 1 + +def test_graph_error(): + a = tensor(np.arange(8).reshape((2, 4))) + b = tensor(np.arange(8).reshape((2, 4))) + + @trace(symbolic=True) + def fn(a, b): + return a + b + + fn(a, b) + with pytest.raises(RuntimeError): + fn(a, b.transpose()) + fn(a, b) @pytest.mark.parametrize("trace_mode", [False, True]) @@ -653,9 +710,10 @@ def test_trace_jit_config(): x = tensor(2) y = func(x) - func._compile() + y = func(x) + # func._compile() - options = func._graph.options + options = func._trace.options mapping = {None: 0, False: 1, True: 2} assert options.graph_opt.jit == 0 assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle] diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index 96f6f4fc..a2518862 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -82,9 +82,10 @@ def test_tqt(): x = mge.tensor(x, dtype="float32") s = mge.tensor(s, dtype="float32") g_y = mge.tensor(g_y, dtype="float32") - grad = Grad().wrt(x, s, callback=cb) - y = tqt_forward(-127, 127, x, s) - grad(y, g_y) + with Grad() as grad: + grad.wrt(x, s, callback=cb) + y = tqt_forward(-127, 127, x, s) + grad(y, g_y) g_x, g_s = g np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) @@ -131,14 +132,16 @@ def test_fakequant(): # test backward x = tensor(inp_data, dtype=np.float32) - grad = Grad().wrt(x, callback=_save_to(x)) - y = fake_quant_tensor(x, qparams) - grad(y, tensor(F.ones_like(x))) + with Grad() as grad: + grad.wrt(x, callback=_save_to(x)) + y = 
fake_quant_tensor(x, qparams) + grad(y, tensor(F.ones_like(x))) x1 = tensor(inp_data, dtype=np.float32) - grad = Grad().wrt(x1, callback=_save_to(x1)) - y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) - grad(y1, tensor(F.ones_like(x1))) + with Grad() as grad: + grad.wrt(x1, callback=_save_to(x1)) + y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) + grad(y1, tensor(F.ones_like(x1))) assert np.allclose(x.grad.numpy(), x1.grad.numpy()) assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) @@ -237,9 +240,10 @@ def test_lsq(): grad_s = mge.tensor(grad_s, dtype="float32") g_y = mge.tensor(g_y, dtype="float32") - grad = Grad().wrt(x, s, callback=cb) - y = lsq_forward(-127, 127, x, s, zero_point, grad_s) - grad(y, g_y) + with Grad() as grad: + grad.wrt(x, s, callback=cb) + y = lsq_forward(-127, 127, x, s, zero_point, grad_s) + grad(y, g_y) g_x, g_s = g np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index 1083e947..f872a0c2 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -430,9 +430,10 @@ def test_ShuffleRNG(): n, m = 6, 3 arr = np.arange(n * m) out0 = Tensor(arr, dtype="float32") - grad = Grad().wrt(out0, callback=cb) - random.shuffle(out0) - grad(out0, F.ones_like(out0)) + with Grad() as grad: + grad.wrt(out0, callback=cb) + random.shuffle(out0) + grad(out0, F.ones_like(out0)) m1 = RNG(seed=111, device="xpu0") m2 = RNG(seed=111, device="xpu1") m3 = RNG(seed=222, device="xpu0")
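
[Editor's note] Taken together, the `test_tracing.py` changes above pin down the new tracing failure modes: a replayed call that diverges from the recording raises `megengine.jit.TraceError` with a specific reason string, while an ordinary user exception thrown mid-trace (`test_exception_in_trace`) propagates unchanged and the traced function stays usable on later calls. A trimmed sketch of the mismatch path, following `test_trace_mismatch`:

    from megengine import tensor
    from megengine.jit import TraceError, trace

    use_sub = False

    @trace(symbolic=True)
    def fn(a, b, c):
        # later calls may take a different path than the recorded one
        return a + b - c if use_sub else a + b + c

    a, b, c = tensor([1.0]), tensor([2.0]), tensor([3.0])
    fn(a, b, c)        # first call records a + b + c
    use_sub = True
    try:
        fn(a, b, c)    # replay meets a subtraction where an add was recorded
    except TraceError:
        pass           # reported as "operator mismatch" in the tests above

As the looped tests demonstrate, a `TraceError` does not poison the trace: subsequent calls that match the recording again return correct results.
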