# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

import megengine as mge

from .._imperative_rt import core2
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function

# ``Tensor`` and ``get_context`` are used below (in ``get_tensor`` and
# ``tracer_apply``); this import path is assumed.
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils

"""
Some notes:
1. Initialize the optimizer:
    for each trainable parameter:
        call wrt(param, callback)
    Each parameter tensor will be associated with a Tracer object saved in Tensor._extra_data
2. Tracer has one member: node, which is a VariableNode
3. VariableNode has an OpNode member: opnode
4. OpNode has six members:
    a. id
    b. inputs, which is made of VariableNode
    c. outputs, which are weakrefs to VariableNode
    d. backward: callback function
    e. has_grad_fn: call has_grad_fn(opnode, reached) to check whether grad exists
    f. backward_allow_noinput: whether backward allows missing inputs
"""
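# A rough sketch of the object graph described in the notes above (no new
# names; everything comes from the notes and the classes defined below):
#
#   trainable Tensor  --wrt(param, callback)-->  Tracer (kept in Tensor._extra_data)
#   Tracer.node           -> VariableNode
#   VariableNode.opnode   -> OpNode that produced the variable
#   OpNode.inputs         -> list of VariableNode
#   OpNode.outputs        -> list of weakrefs to VariableNode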
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
    return c


def get_tensor(x):
    # recursively unwrap __wrapped__ until a real Tensor is reached
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True


class OpNode(clearable):
    """
    OpNode saves all the information needed to form the computational graph.
    """

    def __init__(self):
        self.id = None
        self.inputs = None  # list of VariableNode
        self.outputs = None  # list of weakrefs to VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """
    VariableNode saves an OpNode and a callback.

    ``manager`` is a weak reference to the Grad instance tracking this node;
    ``owner`` is a weak reference to the Tensor the node belongs to.
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is Grad type
        self.manager = weakref.ref(manager)
        # owner is Tensor type
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """
        type(node) is VariableNode
        """
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False


@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    # register the backward method: a tuple of backward functions corresponding
    # to dy / dx_i, where None means y is not a function of x_i
    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(ctx.outputs) == len(output_need_grad)
    if not any(output_need_grad):
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
    if isinstance(op, RemoteSend):
        manager.remote_send_cache.append(opnode)
    opnode.backward = backward
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
    opnode.backward_allow_noinput = check_backward_allow_noinput(op)
    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None


class Grad:
    def __init__(self):
        self._impl = core2.GradKey()

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        core2.backward(self._impl, ys, dys)

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        del self._impl
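# A minimal usage sketch for ``Grad``. ``w`` and the quadratic loss are assumed
# example values; user code normally reaches this machinery through
# ``megengine.autodiff.GradManager`` rather than instantiating ``Grad`` directly.
#
#     import megengine as mge
#     import megengine.functional as F
#
#     w = mge.Tensor([1.0, 2.0, 3.0])
#     with Grad() as grad:
#         grad.wrt(w, callback=lambda dw: print("dw =", dw))
#         y = F.sum(w * w)
#         grad(y, mge.Tensor(1.0))  # callback receives dy/dw == 2 * w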