GitOrigin-RevId: ca5a6ed8eb
release-1.2
@@ -12,6 +12,7 @@ import itertools | |||||
import numpy as np | import numpy as np | ||||
from .._imperative_rt import TensorAttr, imperative | from .._imperative_rt import TensorAttr, imperative | ||||
from .._imperative_rt.core2 import apply | |||||
from ..ops.builtin import ( | from ..ops.builtin import ( | ||||
Broadcast, | Broadcast, | ||||
Elemwise, | Elemwise, | ||||
@@ -25,37 +26,6 @@ from ..ops.builtin import ( | |||||
Subtensor, | Subtensor, | ||||
) | ) | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from ..tensor.core import apply | |||||
from ..tensor.function import Function | |||||
@functools.singledispatch | |||||
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||||
assert 0 | |||||
@builtin_op_get_backward_fn.register(OpDef) | |||||
def _(op: OpDef, inputs, outputs, input_requires_grad): | |||||
if isinstance(op, Reshape): | |||||
grad_fn = reshape_grad_fn | |||||
elif isinstance(op, Subtensor): | |||||
grad_fn = subtensor_grad_fn | |||||
elif isinstance(op, IndexingMultiAxisVec): | |||||
grad_fn = indexingMultiAxisVec_grad_fn | |||||
elif isinstance(op, Broadcast) or ( | |||||
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||||
): | |||||
grad_fn = elemwise_add_grad_fn | |||||
elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: | |||||
grad_fn = reduce_sum_grad_fn | |||||
else: | |||||
grad_fn = default_grad_fn | |||||
return grad_fn(op, inputs, outputs, input_requires_grad) | |||||
@builtin_op_get_backward_fn.register(Function) | |||||
def _(op: Function, inputs, outputs, input_requires_grad): | |||||
return op.get_backward_fn(), [True,] * len(outputs) | |||||
def default_grad_fn(op, inputs, outputs, input_requires_grad): | def default_grad_fn(op, inputs, outputs, input_requires_grad): | ||||
@@ -19,8 +19,6 @@ import megengine as mge | |||||
from .._imperative_rt import core2, ops | from .._imperative_rt import core2, ops | ||||
from ..ops.builtin import Elemwise, OpDef, RemoteSend | from ..ops.builtin import Elemwise, OpDef, RemoteSend | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||||
from ..tensor.function import Function | |||||
from . import builtin_op_utils | from . import builtin_op_utils | ||||
""" Some notes: | """ Some notes: | ||||
@@ -48,146 +46,6 @@ def get_grad_managers(): | |||||
return [_grad_manager_dict[key] for key in _grad_manager_dict] | return [_grad_manager_dict[key] for key in _grad_manager_dict] | ||||
def add(a, b): | |||||
(c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b) | |||||
return c | |||||
def get_tensor(x): | |||||
# use recursion to avoid infinite loop | |||||
if isinstance(x, Tensor): | |||||
return x | |||||
try: | |||||
x = x.__wrapped__ | |||||
except AttributeError: | |||||
raise TypeError(type(x)) | |||||
return get_tensor(x) | |||||
class clearable: | |||||
__cleared = False | |||||
def __bool__(self): | |||||
return not self.__cleared | |||||
def clear(self): | |||||
self.__dict__.clear() | |||||
self.__cleared = True | |||||
class OpNode(clearable): | |||||
""" OpNode saves all the information to form the computational graph. | |||||
""" | |||||
def __init__(self): | |||||
self.id = None | |||||
self.inputs = None # Could be VariableNode | |||||
self.outputs = None # Could be VariableNode | |||||
self.backward = None | |||||
self.has_grad_fn = None | |||||
self.backward_allow_noinput = False | |||||
class VariableNode(clearable): | |||||
""" VariableNode saves OpNode and callback. | |||||
FIXME!!! Explain manager and owner | |||||
""" | |||||
def __init__(self, manager, owner, opnode=None, callback=None): | |||||
# manager is Grad type | |||||
self.manager = weakref.ref(manager) | |||||
# owner is Tensor type | |||||
self.owner = weakref.ref(owner) | |||||
self.opnode = opnode | |||||
self.callback = callback | |||||
class Tracer(clearable, TensorBase): | |||||
def __init__(self, node=None): | |||||
""" type(node) is VariableNode | |||||
""" | |||||
self.node = node | |||||
@functools.singledispatch | |||||
def check_backward_allow_noinput(op: OpDef): | |||||
return False | |||||
@functools.singledispatch | |||||
def get_op_has_grad_fn(op: OpDef): | |||||
assert 0 | |||||
@get_op_has_grad_fn.register(OpDef) | |||||
def _(op: OpDef): | |||||
return default_has_grad_fn | |||||
@get_op_has_grad_fn.register(Function) | |||||
def _(op: Function): | |||||
return default_has_grad_fn | |||||
def default_has_grad_fn(opnode, reached): | |||||
for v in opnode.outputs: | |||||
if v() in reached: | |||||
return True | |||||
return False | |||||
@apply.register() | |||||
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||||
args = tuple(i if isinstance(i, Tracer) else None for i in args) | |||||
input_requires_grad = list(map(bool, args)) | |||||
if not any(input_requires_grad): | |||||
return | |||||
ctx = get_context() | |||||
manager = None | |||||
assert len(ctx.inputs) == len(args) | |||||
for i, j in zip(ctx.inputs, args): | |||||
if j: | |||||
j = j.node | |||||
assert i is j.owner() | |||||
if manager is None: | |||||
manager = j.manager() | |||||
assert manager | |||||
else: | |||||
assert manager is j.manager() | |||||
if not manager._enabled: | |||||
return | |||||
# register backward method | |||||
# tuple of backward functions corresponding to dy / dx_i | |||||
# None means y is not a function of x_i | |||||
backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn( | |||||
op, ctx.inputs, ctx.outputs, input_requires_grad | |||||
) | |||||
assert len(ctx.outputs) == len(output_need_grad) | |||||
if not any(output_need_grad): | |||||
return | |||||
opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs) | |||||
if isinstance(op, RemoteSend): | |||||
manager.remote_send_cache.append(opnode) | |||||
opnode.backward = backward | |||||
outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)] | |||||
opnode.backward_allow_noinput = check_backward_allow_noinput(op) | |||||
opnode.has_grad_fn = get_op_has_grad_fn(op) | |||||
return tuple(outputs) | |||||
@apply.register() | |||||
def _(op: Const, *_: typing.Optional[Tracer]): | |||||
return None | |||||
class Grad: | class Grad: | ||||
def __init__(self): | def __init__(self): | ||||
self._impl = core2.GradKey() | self._impl = core2.GradKey() | ||||
@@ -8,9 +8,6 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import numpy as np | import numpy as np | ||||
# from .._imperative_rt.core2 import Tensor | |||||
from ..tensor.core import OpBase, TensorBase, apply | |||||
class Const: | class Const: | ||||
def __init__(self, value=None, *, dtype=None, device=None): | def __init__(self, value=None, *, dtype=None, device=None): | ||||
@@ -13,12 +13,9 @@ import sys | |||||
import typing | import typing | ||||
from abc import ABC | from abc import ABC | ||||
from .multipledispatch import Dispatcher | |||||
class OpBase(ABC): | |||||
def __call__(self, *args): | |||||
return apply(self, *args) | |||||
class OpBase: | |||||
pass | |||||
class TensorBase: | class TensorBase: | ||||
@@ -27,22 +24,3 @@ class TensorBase: | |||||
class TensorWrapperBase: | class TensorWrapperBase: | ||||
pass | pass | ||||
apply = Dispatcher("apply") | |||||
OpBase.apply = apply | |||||
@apply.register() | |||||
def _(op: OpBase, *args: TensorBase): | |||||
raise NotImplementedError | |||||
@apply.register() | |||||
def _(op: OpBase, *args: TensorWrapperBase): | |||||
assert args | |||||
Wrapper = type(args[0]) | |||||
outputs = apply(op, *(i.__wrapped__ for i in args)) | |||||
assert isinstance(outputs, tuple) | |||||
return tuple(map(Wrapper, outputs)) |
@@ -1,154 +0,0 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from ..ops.builtin import OpDef | |||||
from .core import TensorBase, TensorWrapperBase, apply | |||||
class Function: | |||||
""" | |||||
Defines a block of operations with customizable differentiation. | |||||
The computation should be defined in ``forward`` method, with gradient | |||||
computation defined in ``backward`` method. | |||||
Each instance of ``Function`` should be used only once during forwardding. | |||||
Examples: | |||||
.. testcode:: | |||||
class Sigmoid(Function): | |||||
def forward(self, x): | |||||
y = 1 / (1 + F.exp(-x)) | |||||
self.y = y | |||||
return y | |||||
def backward(self, output_grads): | |||||
y = self.y | |||||
return output_grads * y * (1-y) | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
pass | |||||
def __call__(self, *args): | |||||
ret = apply(self, *args) | |||||
if type(ret) == tuple and len(ret) == 1: | |||||
return ret[0] | |||||
return ret | |||||
def forward(self, *args, **kwargs): | |||||
""" | |||||
Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses. | |||||
:param input: input tensors. | |||||
:return: a tuple of Tensor or a single Tensor. | |||||
.. note:: | |||||
This method should return a tuple of Tensor or a single Tensor representing the output | |||||
of the function. | |||||
""" | |||||
raise NotImplementedError | |||||
def backward(self, *output_grads): | |||||
""" | |||||
Compute the gradient of the forward function. It must be overriden by all subclasses. | |||||
:param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`. | |||||
.. note:: | |||||
In case when some tensors of outputs are not related to loss function, the corresponding | |||||
values in ``output_grads`` would be ``None``. | |||||
.. note:: | |||||
This method should return a tuple which containing the gradients of all inputs, in the same order | |||||
as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned | |||||
instead if there is only one input. If users want to stop the propagation of some gradients, | |||||
the corresponding returned values should be set ``None`` . | |||||
""" | |||||
raise NotImplementedError | |||||
def get_backward_fn(self): | |||||
if self.backward is None: | |||||
return None | |||||
def _backward(*output_grads): | |||||
if type(output_grads) is tuple: | |||||
_output_grads = [ | |||||
TensorWrapper(i) if i is not None else i for i in output_grads | |||||
] | |||||
else: | |||||
_output_grads = ( | |||||
TensorWrapper(output_grads) | |||||
if output_grads is not None | |||||
else output_grads, | |||||
) | |||||
ret = self.backward(*_output_grads) | |||||
if type(ret) is not tuple: | |||||
ret = (ret,) | |||||
ret = tuple( | |||||
i.__wrapped__ if isinstance(i, TensorWrapper) else i for i in ret | |||||
) | |||||
return ret | |||||
return _backward | |||||
Function.apply = Function.__call__ | |||||
@apply.register() | |||||
def _(op: Function, *args: TensorWrapperBase): | |||||
assert args | |||||
Wrapper = type(args[0]) | |||||
# compute the value for self define function | |||||
extra_data_dic = {} | |||||
for arg in args: | |||||
extra_data_dic[arg.__wrapped__] = arg.__wrapped__._extra_data | |||||
arg.__wrapped__._extra_data = {} | |||||
rets = op.forward(*args) | |||||
for arg in args: | |||||
arg.__wrapped__._extra_data = extra_data_dic[arg.__wrapped__] | |||||
# update the gradient information for self define function | |||||
inputs = tuple(map(lambda i: i.__wrapped__, args)) | |||||
outputs = ( | |||||
tuple(map(lambda i: i.__wrapped__, rets)) | |||||
if type(rets) is tuple | |||||
else (rets.__wrapped__,) | |||||
) | |||||
for output in outputs: | |||||
if output not in inputs: | |||||
output._extra_data = {} | |||||
with push_context() as ctx: | |||||
ctx.inputs = inputs | |||||
ctx.outputs = outputs | |||||
for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))): | |||||
ctx.key = k | |||||
data = tuple( | |||||
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs | |||||
) | |||||
# data are instances of Tracer | |||||
# dispatched to apply.add@grad.py | |||||
rets = apply(op, *data) | |||||
if rets is not None: | |||||
assert len(outputs) == len(rets) | |||||
for t, i in zip(outputs, rets): | |||||
t._extra_data[k] = i | |||||
return tuple(map(Wrapper, outputs)) |
@@ -1,53 +0,0 @@ | |||||
# Copyright (c) 2014 Matthew Rocklin | |||||
# | |||||
# All rights reserved. | |||||
# | |||||
# Redistribution and use in source and binary forms, with or without | |||||
# modification, are permitted provided that the following conditions are met: | |||||
# | |||||
# a. Redistributions of source code must retain the above copyright notice, | |||||
# this list of conditions and the following disclaimer. | |||||
# b. Redistributions in binary form must reproduce the above copyright | |||||
# notice, this list of conditions and the following disclaimer in the | |||||
# documentation and/or other materials provided with the distribution. | |||||
# c. Neither the name of multipledispatch nor the names of its contributors | |||||
# may be used to endorse or promote products derived from this software | |||||
# without specific prior written permission. | |||||
# | |||||
# | |||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
# DAMAGE. | |||||
# | |||||
# -------------------------------------------------------------------------------------- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
# -------------------------------------------------------------------------------------- | |||||
# This directory is a fork of multipledispatch. | |||||
# | |||||
# Repo: https://github.com/mrocklin/multipledispatch | |||||
# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834 | |||||
# Authors: Matthew Rocklin et al. | |||||
# | |||||
# The original LICENSE file is included in the ACKNOWLEDGEMENT file under | |||||
# MegEngine root directory. | |||||
from .core import dispatch | |||||
from .dispatcher import Dispatcher |
@@ -1,165 +0,0 @@ | |||||
# Copyright (c) 2014 Matthew Rocklin | |||||
# | |||||
# All rights reserved. | |||||
# | |||||
# Redistribution and use in source and binary forms, with or without | |||||
# modification, are permitted provided that the following conditions are met: | |||||
# | |||||
# a. Redistributions of source code must retain the above copyright notice, | |||||
# this list of conditions and the following disclaimer. | |||||
# b. Redistributions in binary form must reproduce the above copyright | |||||
# notice, this list of conditions and the following disclaimer in the | |||||
# documentation and/or other materials provided with the distribution. | |||||
# c. Neither the name of multipledispatch nor the names of its contributors | |||||
# may be used to endorse or promote products derived from this software | |||||
# without specific prior written permission. | |||||
# | |||||
# | |||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
# DAMAGE. | |||||
# | |||||
# -------------------------------------------------------------------------------------- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
# -------------------------------------------------------------------------------------- | |||||
from collections import OrderedDict | |||||
from .utils import _toposort, groupby | |||||
from .variadic import isvariadic | |||||
class AmbiguityWarning(Warning): | |||||
pass | |||||
def supercedes(a, b): | |||||
""" A is consistent and strictly more specific than B """ | |||||
if len(a) < len(b): | |||||
# only case is if a is empty and b is variadic | |||||
return not a and len(b) == 1 and isvariadic(b[-1]) | |||||
elif len(a) == len(b): | |||||
return all(map(issubclass, a, b)) | |||||
else: | |||||
# len(a) > len(b) | |||||
p1 = 0 | |||||
p2 = 0 | |||||
while p1 < len(a) and p2 < len(b): | |||||
cur_a = a[p1] | |||||
cur_b = b[p2] | |||||
if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||||
if not issubclass(cur_a, cur_b): | |||||
return False | |||||
p1 += 1 | |||||
p2 += 1 | |||||
elif isvariadic(cur_a): | |||||
assert p1 == len(a) - 1 | |||||
return p2 == len(b) - 1 and issubclass(cur_a, cur_b) | |||||
elif isvariadic(cur_b): | |||||
assert p2 == len(b) - 1 | |||||
if not issubclass(cur_a, cur_b): | |||||
return False | |||||
p1 += 1 | |||||
return p2 == len(b) - 1 and p1 == len(a) | |||||
def consistent(a, b): | |||||
""" It is possible for an argument list to satisfy both A and B """ | |||||
# Need to check for empty args | |||||
if not a: | |||||
return not b or isvariadic(b[0]) | |||||
if not b: | |||||
return not a or isvariadic(a[0]) | |||||
# Non-empty args check for mutual subclasses | |||||
if len(a) == len(b): | |||||
return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) | |||||
else: | |||||
p1 = 0 | |||||
p2 = 0 | |||||
while p1 < len(a) and p2 < len(b): | |||||
cur_a = a[p1] | |||||
cur_b = b[p2] | |||||
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): | |||||
return False | |||||
if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||||
p1 += 1 | |||||
p2 += 1 | |||||
elif isvariadic(cur_a): | |||||
p2 += 1 | |||||
elif isvariadic(cur_b): | |||||
p1 += 1 | |||||
# We only need to check for variadic ends | |||||
# Variadic types are guaranteed to be the last element | |||||
return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a) | |||||
def ambiguous(a, b): | |||||
""" A is consistent with B but neither is strictly more specific """ | |||||
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) | |||||
def ambiguities(signatures): | |||||
""" All signature pairs such that A is ambiguous with B """ | |||||
signatures = list(map(tuple, signatures)) | |||||
return set( | |||||
(a, b) | |||||
for a in signatures | |||||
for b in signatures | |||||
if hash(a) < hash(b) | |||||
and ambiguous(a, b) | |||||
and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) | |||||
) | |||||
def super_signature(signatures): | |||||
""" A signature that would break ambiguities """ | |||||
n = len(signatures[0]) | |||||
assert all(len(s) == n for s in signatures) | |||||
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)] | |||||
def edge(a, b, tie_breaker=hash): | |||||
""" A should be checked before B | |||||
Tie broken by tie_breaker, defaults to ``hash`` | |||||
""" | |||||
# A either supercedes B and B does not supercede A or if B does then call | |||||
# tie_breaker | |||||
return supercedes(a, b) and ( | |||||
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) | |||||
) | |||||
def ordering(signatures): | |||||
""" A sane ordering of signatures to check, first to last | |||||
Topoological sort of edges as given by ``edge`` and ``supercedes`` | |||||
""" | |||||
signatures = list(map(tuple, signatures)) | |||||
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] | |||||
edges = groupby(lambda x: x[0], edges) | |||||
for s in signatures: | |||||
if s not in edges: | |||||
edges[s] = [] | |||||
edges = OrderedDict((k, [b for a, b in v]) for k, v in edges.items()) | |||||
return _toposort(edges) |
@@ -1,130 +0,0 @@ | |||||
# Copyright (c) 2014 Matthew Rocklin | |||||
# | |||||
# All rights reserved. | |||||
# | |||||
# Redistribution and use in source and binary forms, with or without | |||||
# modification, are permitted provided that the following conditions are met: | |||||
# | |||||
# a. Redistributions of source code must retain the above copyright notice, | |||||
# this list of conditions and the following disclaimer. | |||||
# b. Redistributions in binary form must reproduce the above copyright | |||||
# notice, this list of conditions and the following disclaimer in the | |||||
# documentation and/or other materials provided with the distribution. | |||||
# c. Neither the name of multipledispatch nor the names of its contributors | |||||
# may be used to endorse or promote products derived from this software | |||||
# without specific prior written permission. | |||||
# | |||||
# | |||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
# DAMAGE. | |||||
# | |||||
# -------------------------------------------------------------------------------------- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
# -------------------------------------------------------------------------------------- | |||||
import inspect | |||||
import sys | |||||
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn | |||||
global_namespace = dict() | |||||
def dispatch(*types, **kwargs): | |||||
""" Dispatch function on the types of the inputs | |||||
Supports dispatch on all non-keyword arguments. | |||||
Collects implementations based on the function name. Ignores namespaces. | |||||
If ambiguous type signatures occur a warning is raised when the function is | |||||
defined suggesting the additional method to break the ambiguity. | |||||
Examples | |||||
-------- | |||||
>>> @dispatch(int) | |||||
... def f(x): | |||||
... return x + 1 | |||||
>>> @dispatch(float) | |||||
... def f(x): | |||||
... return x - 1 | |||||
>>> f(3) | |||||
4 | |||||
>>> f(3.0) | |||||
2.0 | |||||
Specify an isolated namespace with the namespace keyword argument | |||||
>>> my_namespace = dict() | |||||
>>> @dispatch(int, namespace=my_namespace) | |||||
... def foo(x): | |||||
... return x + 1 | |||||
Dispatch on instance methods within classes | |||||
>>> class MyClass(object): | |||||
... @dispatch(list) | |||||
... def __init__(self, data): | |||||
... self.data = data | |||||
... @dispatch(int) | |||||
... def __init__(self, datum): | |||||
... self.data = [datum] | |||||
""" | |||||
namespace = kwargs.get("namespace", global_namespace) | |||||
types = tuple(types) | |||||
def _df(func): | |||||
name = func.__name__ | |||||
if ismethod(func): | |||||
dispatcher = inspect.currentframe().f_back.f_locals.get( | |||||
name, MethodDispatcher(name), | |||||
) | |||||
else: | |||||
if name not in namespace: | |||||
namespace[name] = Dispatcher(name) | |||||
dispatcher = namespace[name] | |||||
dispatcher.add(types, func) | |||||
return dispatcher | |||||
return _df | |||||
def ismethod(func): | |||||
""" Is func a method? | |||||
Note that this has to work as the method is defined but before the class is | |||||
defined. At this stage methods look like functions. | |||||
""" | |||||
if hasattr(inspect, "signature"): | |||||
signature = inspect.signature(func) | |||||
return signature.parameters.get("self", None) is not None | |||||
else: | |||||
if sys.version_info.major < 3: | |||||
spec = inspect.getargspec(func) | |||||
else: | |||||
spec = inspect.getfullargspec(func) | |||||
return spec and spec.args and spec.args[0] == "self" |
@@ -1,445 +0,0 @@ | |||||
# Copyright (c) 2014 Matthew Rocklin | |||||
# | |||||
# All rights reserved. | |||||
# | |||||
# Redistribution and use in source and binary forms, with or without | |||||
# modification, are permitted provided that the following conditions are met: | |||||
# | |||||
# a. Redistributions of source code must retain the above copyright notice, | |||||
# this list of conditions and the following disclaimer. | |||||
# b. Redistributions in binary form must reproduce the above copyright | |||||
# notice, this list of conditions and the following disclaimer in the | |||||
# documentation and/or other materials provided with the distribution. | |||||
# c. Neither the name of multipledispatch nor the names of its contributors | |||||
# may be used to endorse or promote products derived from this software | |||||
# without specific prior written permission. | |||||
# | |||||
# | |||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
# DAMAGE. | |||||
# | |||||
# -------------------------------------------------------------------------------------- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
# -------------------------------------------------------------------------------------- | |||||
import copy | |||||
import inspect | |||||
import itertools as itl | |||||
from warnings import warn | |||||
from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher | |||||
from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature | |||||
from .utils import expand_tuples, parse_union | |||||
from .variadic import Variadic, isvariadic | |||||
def ambiguity_warn(dispatcher, ambiguities): | |||||
""" Raise warning when ambiguity is detected | |||||
Parameters | |||||
---------- | |||||
dispatcher : Dispatcher | |||||
The dispatcher on which the ambiguity was detected | |||||
ambiguities : set | |||||
Set of type signature pairs that are ambiguous within this dispatcher | |||||
See Also: | |||||
Dispatcher.add | |||||
warning_text | |||||
""" | |||||
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) | |||||
def variadic_signature_matches_iter(types, full_signature): | |||||
""" | |||||
Check if a set of input types matches a variadic signature. | |||||
Notes | |||||
----- | |||||
The algorithm is as follows: | |||||
Initialize the current signature to the first in the sequence | |||||
For each type in `types`: | |||||
If the current signature is variadic | |||||
If the type matches the signature | |||||
yield True | |||||
Else | |||||
Try to get the next signature | |||||
If no signatures are left we can't possibly have a match | |||||
so yield False | |||||
Else | |||||
yield True if the type matches the current signature | |||||
Get the next signature | |||||
""" | |||||
sigiter = iter(full_signature) | |||||
sig = next(sigiter) | |||||
for typ in types: | |||||
matches = issubclass(typ, sig) | |||||
yield matches | |||||
if not isvariadic(sig): | |||||
# we're not matching a variadic argument, so move to the next | |||||
# element in the signature | |||||
sig = next(sigiter) | |||||
else: | |||||
try: | |||||
sig = next(sigiter) | |||||
except StopIteration: | |||||
assert isvariadic(sig) | |||||
yield True | |||||
else: | |||||
# We have signature items left over, so all of our arguments | |||||
# haven't matched | |||||
yield False | |||||
def variadic_signature_matches(types, full_signature): | |||||
# No arguments always matches a variadic signature | |||||
assert full_signature | |||||
return all(variadic_signature_matches_iter(types, full_signature)) | |||||
def get_func_signature(function): | |||||
sig = inspect.signature(function) | |||||
types = [] | |||||
for p in sig.parameters.values(): | |||||
ann = p.annotation | |||||
ann = parse_union(ann) or ann | |||||
if p.kind in ( | |||||
inspect.Parameter.POSITIONAL_ONLY, | |||||
inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||||
): | |||||
types.append(ann) | |||||
if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||||
types.append([ann]) | |||||
return tuple(types) | |||||
class Frame: | |||||
__slots__ = "args", "types", "mro", "mro_offset" | |||||
class Dispatcher(CDispatcher): | |||||
""" Dispatch methods based on type signature | |||||
Use ``dispatch`` to add implementations | |||||
Examples | |||||
-------- | |||||
>>> from multipledispatch import dispatch | |||||
>>> @dispatch(int) | |||||
... def f(x): | |||||
... return x + 1 | |||||
>>> @dispatch(float) | |||||
... def f(x): | |||||
... return x - 1 | |||||
>>> f(3) | |||||
4 | |||||
>>> f(3.0) | |||||
2.0 | |||||
""" | |||||
__slots__ = "__name__", "name", "funcs", "_ordering", "doc" | |||||
def __init__(self, name, doc=None): | |||||
self.name = self.__name__ = name | |||||
self.funcs = {} | |||||
self.doc = doc | |||||
def register(self, *types, **kwargs): | |||||
""" register dispatcher with new implementation | |||||
>>> f = Dispatcher('f') | |||||
>>> @f.register(int) | |||||
... def inc(x): | |||||
... return x + 1 | |||||
>>> @f.register(float) | |||||
... def dec(x): | |||||
... return x - 1 | |||||
>>> @f.register(list) | |||||
... @f.register(tuple) | |||||
... def reverse(x): | |||||
... return x[::-1] | |||||
>>> f(1) | |||||
2 | |||||
>>> f(1.0) | |||||
0.0 | |||||
>>> f([1, 2, 3]) | |||||
[3, 2, 1] | |||||
""" | |||||
def _df(func): | |||||
self.add(types, func, **kwargs) | |||||
return func | |||||
return _df | |||||
def add(self, signature, func): | |||||
""" Add new types/method pair to dispatcher | |||||
>>> D = Dispatcher('add') | |||||
>>> D.add((int, int), lambda x, y: x + y) | |||||
>>> D.add((float, float), lambda x, y: x + y) | |||||
>>> D(1, 2) | |||||
3 | |||||
>>> D(1, 2.0) | |||||
Traceback (most recent call last): | |||||
... | |||||
NotImplementedError: Could not find signature for add: <int, float> | |||||
When ``add`` detects a warning it calls the ``on_ambiguity`` callback | |||||
with a dispatcher/itself, and a set of ambiguous type signature pairs | |||||
as inputs. See ``ambiguity_warn`` for an example. | |||||
""" | |||||
# Handle annotations | |||||
if not signature: | |||||
signature = get_func_signature(func) | |||||
# Handle union types | |||||
if any(isinstance(typ, tuple) for typ in signature): | |||||
for typs in expand_tuples(signature): | |||||
self.add(typs, func) | |||||
return | |||||
new_signature = [] | |||||
for index, typ in enumerate(signature, start=1): | |||||
if not isinstance(typ, (type, list)): | |||||
str_sig = ", ".join( | |||||
c.__name__ if isinstance(c, type) else str(c) for c in signature | |||||
) | |||||
raise TypeError( | |||||
"Tried to dispatch on non-type: %s\n" | |||||
"In signature: <%s>\n" | |||||
"In function: %s" % (typ, str_sig, self.name) | |||||
) | |||||
# handle variadic signatures | |||||
if isinstance(typ, list): | |||||
if index != len(signature): | |||||
raise TypeError("Variadic signature must be the last element") | |||||
if len(typ) != 1: | |||||
raise TypeError( | |||||
"Variadic signature must contain exactly one element. " | |||||
"To use a variadic union type place the desired types " | |||||
"inside of a tuple, e.g., [(int, str)]" | |||||
) | |||||
new_signature.append(Variadic[typ[0]]) | |||||
else: | |||||
new_signature.append(typ) | |||||
l = self.funcs.setdefault(tuple(new_signature), []) | |||||
for i in l: | |||||
if i is func: | |||||
raise ValueError("already registered") | |||||
l.append(func) | |||||
self.enable(func) | |||||
self.clear_cache() | |||||
try: | |||||
del self._ordering | |||||
except AttributeError: | |||||
pass | |||||
@property | |||||
def ordering(self): | |||||
try: | |||||
return self._ordering | |||||
except AttributeError: | |||||
return self.reorder() | |||||
def reorder(self, on_ambiguity=ambiguity_warn): | |||||
self._ordering = od = ordering(self.funcs) | |||||
amb = ambiguities(self.funcs) | |||||
if amb: | |||||
on_ambiguity(self, amb) | |||||
return od | |||||
def __str__(self): | |||||
return "<dispatched %s>" % self.name | |||||
__repr__ = __str__ | |||||
def dispatch(self, *types): | |||||
""" | |||||
Deterimine appropriate implementation for this type signature | |||||
This method is internal. Users should call this object as a function. | |||||
Implementation resolution occurs within the ``__call__`` method. | |||||
>>> from multipledispatch import dispatch | |||||
>>> @dispatch(int) | |||||
... def inc(x): | |||||
... return x + 1 | |||||
>>> implementation = inc.dispatch(int) | |||||
>>> implementation(3) | |||||
4 | |||||
>>> print(inc.dispatch(float)) | |||||
None | |||||
See Also: | |||||
``multipledispatch.conflict`` - module to determine resolution order | |||||
""" | |||||
if types in self.funcs: | |||||
return self.funcs[types][-1] | |||||
for f in self.dispatch_iter(*types): | |||||
return f | |||||
def dispatch_iter(self, *types): | |||||
n = len(types) | |||||
for signature in self.ordering: | |||||
if ( | |||||
len(signature) == n | |||||
and all(map(issubclass, types, signature)) | |||||
or len(signature) | |||||
and isvariadic(signature[-1]) | |||||
and variadic_signature_matches(types, signature) | |||||
): | |||||
yield from self.funcs[signature][::-1] | |||||
def __getstate__(self): | |||||
return {"name": self.name, "funcs": self.funcs} | |||||
def __setstate__(self, d): | |||||
self.name = d["name"] | |||||
self.funcs = d["funcs"] | |||||
self._ordering = ordering(self.funcs) | |||||
self._cache = dict() | |||||
@property | |||||
def __doc__(self): | |||||
docs = ["Multiply dispatched method: %s" % self.name] | |||||
if self.doc: | |||||
docs.append(self.doc) | |||||
other = [] | |||||
for sig in self.ordering[::-1]: | |||||
funcs = self.funcs[sig] | |||||
s = "Inputs: <%s>\n" % str_signature(sig) | |||||
sep = "-" * len(s) + "\n" | |||||
for i, func in enumerate(funcs): | |||||
s += sep | |||||
if len(funcs) > 1: | |||||
s += "[Handler %d]\n\n" % (i + 1) | |||||
if i: | |||||
s += "\n\n" | |||||
if func.__doc__: | |||||
s += func.__doc__.strip() | |||||
else: | |||||
s += repr(func) + "\n" | |||||
docs.append(s) | |||||
return "\n\n".join(docs) | |||||
def _help(self, *args): | |||||
return self.dispatch(*map(type, args)).__doc__ | |||||
def help(self, *args, **kwargs): | |||||
""" Print docstring for the function corresponding to inputs """ | |||||
print(self._help(*args)) | |||||
def _source(self, *args): | |||||
func = self.dispatch(*map(type, args)) | |||||
if not func: | |||||
raise TypeError("No function found") | |||||
return source(func) | |||||
def source(self, *args, **kwargs): | |||||
""" Print source code for the function corresponding to inputs """ | |||||
print(self._source(*args)) | |||||
def source(func): | |||||
s = "File: %s\n\n" % inspect.getsourcefile(func) | |||||
s = s + inspect.getsource(func) | |||||
return s | |||||
class MethodDispatcher(Dispatcher): | |||||
""" Dispatch methods based on type signature | |||||
See Also: | |||||
Dispatcher | |||||
""" | |||||
__slots__ = ("obj", "cls") | |||||
@classmethod | |||||
def get_func_params(cls, func): | |||||
if hasattr(inspect, "signature"): | |||||
sig = inspect.signature(func) | |||||
return itl.islice(sig.parameters.values(), 1, None) | |||||
def __get__(self, instance, owner): | |||||
self.obj = instance | |||||
self.cls = owner | |||||
return self | |||||
def __call__(self, *args, **kwargs): | |||||
types = tuple([type(arg) for arg in args]) | |||||
func = self.dispatch(*types) | |||||
if not func: | |||||
raise NotImplementedError( | |||||
"Could not find signature for %s: <%s>" | |||||
% (self.name, str_signature(types)) | |||||
) | |||||
return func(self.obj, *args, **kwargs) | |||||
def str_signature(sig): | |||||
""" String representation of type signature | |||||
>>> str_signature((int, float)) | |||||
'int, float' | |||||
""" | |||||
return ", ".join(cls.__name__ for cls in sig) | |||||
def warning_text(name, amb): | |||||
""" The text for ambiguity warnings """ | |||||
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) | |||||
text += "The following signatures may result in ambiguous behavior:\n" | |||||
for pair in amb: | |||||
text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" | |||||
text += "\n\nConsider making the following additions:\n\n" | |||||
text += "\n\n".join( | |||||
[ | |||||
"@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name | |||||
for s in amb | |||||
] | |||||
) | |||||
return text |
@@ -1,210 +0,0 @@ | |||||
# Copyright (c) 2014 Matthew Rocklin | |||||
# | |||||
# All rights reserved. | |||||
# | |||||
# Redistribution and use in source and binary forms, with or without | |||||
# modification, are permitted provided that the following conditions are met: | |||||
# | |||||
# a. Redistributions of source code must retain the above copyright notice, | |||||
# this list of conditions and the following disclaimer. | |||||
# b. Redistributions in binary form must reproduce the above copyright | |||||
# notice, this list of conditions and the following disclaimer in the | |||||
# documentation and/or other materials provided with the distribution. | |||||
# c. Neither the name of multipledispatch nor the names of its contributors | |||||
# may be used to endorse or promote products derived from this software | |||||
# without specific prior written permission. | |||||
# | |||||
# | |||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
# DAMAGE. | |||||
# | |||||
# -------------------------------------------------------------------------------------- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
# -------------------------------------------------------------------------------------- | |||||
import sys | |||||
import typing | |||||
from collections import OrderedDict | |||||
def raises(err, lamda): | |||||
try: | |||||
lamda() | |||||
return False | |||||
except err: | |||||
return True | |||||
def expand_tuples(L): | |||||
""" | |||||
>>> expand_tuples([1, (2, 3)]) | |||||
[(1, 2), (1, 3)] | |||||
>>> expand_tuples([1, 2]) | |||||
[(1, 2)] | |||||
""" | |||||
if not L: | |||||
return [()] | |||||
elif not isinstance(L[0], tuple): | |||||
rest = expand_tuples(L[1:]) | |||||
return [(L[0],) + t for t in rest] | |||||
else: | |||||
rest = expand_tuples(L[1:]) | |||||
return [(item,) + t for t in rest for item in L[0]] | |||||
# Taken from theano/theano/gof/sched.py | |||||
# Avoids licensing issues because this was written by Matthew Rocklin | |||||
def _toposort(edges): | |||||
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices) | |||||
inputs: | |||||
edges - a dict of the form {a: {b, c}} where b and c depend on a | |||||
outputs: | |||||
L - an ordered list of nodes that satisfy the dependencies of edges | |||||
>>> _toposort({1: (2, 3), 2: (3, )}) | |||||
[1, 2, 3] | |||||
Closely follows the wikipedia page [2] | |||||
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks", | |||||
Communications of the ACM | |||||
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms | |||||
""" | |||||
incoming_edges = reverse_dict(edges) | |||||
incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) | |||||
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) | |||||
L = [] | |||||
while S: | |||||
n, _ = S.popitem() | |||||
L.append(n) | |||||
for m in edges.get(n, ()): | |||||
assert n in incoming_edges[m] | |||||
incoming_edges[m].remove(n) | |||||
if not incoming_edges[m]: | |||||
S[m] = None | |||||
if any(incoming_edges.get(v, None) for v in edges): | |||||
raise ValueError("Input has cycles") | |||||
return L | |||||
def reverse_dict(d): | |||||
""" | |||||
Reverses direction of dependence dict | |||||
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} | |||||
>>> reverse_dict(d) # doctest: +SKIP | |||||
{1: ('a',), 2: ('a', 'b'), 3: ('b',)} | |||||
:note: dict order are not deterministic. As we iterate on the | |||||
input dict, it make the output of this function depend on the | |||||
dict order. So this function output order should be considered | |||||
as undeterministic. | |||||
""" | |||||
result = OrderedDict() | |||||
for key in d: | |||||
for val in d[key]: | |||||
result[val] = result.get(val, tuple()) + (key,) | |||||
return result | |||||
# Taken from toolz | |||||
# Avoids licensing issues because this version was authored by Matthew Rocklin | |||||
def groupby(func, seq): | |||||
""" Group a collection by a key function | |||||
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] | |||||
>>> groupby(len, names) # doctest: +SKIP | |||||
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} | |||||
>>> iseven = lambda x: x % 2 == 0 | |||||
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP | |||||
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]} | |||||
See Also: | |||||
``countby`` | |||||
""" | |||||
d = OrderedDict() | |||||
for item in seq: | |||||
key = func(item) | |||||
if key not in d: | |||||
d[key] = list() | |||||
d[key].append(item) | |||||
return d | |||||
def typename(type): | |||||
""" | |||||
Get the name of `type`. | |||||
Parameters | |||||
---------- | |||||
type : Union[Type, Tuple[Type]] | |||||
Returns | |||||
------- | |||||
str | |||||
The name of `type` or a tuple of the names of the types in `type`. | |||||
Examples | |||||
-------- | |||||
>>> typename(int) | |||||
'int' | |||||
>>> typename((int, float)) | |||||
'(int, float)' | |||||
""" | |||||
try: | |||||
return type.__name__ | |||||
except AttributeError: | |||||
if len(type) == 1: | |||||
return typename(*type) | |||||
return "(%s)" % ", ".join(map(typename, type)) | |||||
# parse typing.Union | |||||
def parse_union(ann): | |||||
if hasattr(typing, "UnionMeta"): | |||||
if type(ann) is not typing.UnionMeta: | |||||
return | |||||
return ann.__union_params__ | |||||
elif hasattr(typing, "_Union"): | |||||
if type(ann) is not typing._Union: | |||||
return | |||||
return ann.__args__ | |||||
elif hasattr(typing, "_GenericAlias"): | |||||
if type(ann) is not typing._GenericAlias: | |||||
if type(ann) is not typing.Union: | |||||
return | |||||
else: | |||||
if ann.__origin__ is not typing.Union: | |||||
return | |||||
return ann.__args__ | |||||
elif hasattr(typing, "Union"): | |||||
if typing.get_origin(ann) is not typing.Union: | |||||
return | |||||
return typing.get_args(ann) | |||||
else: | |||||
raise NotImplementedError("unsupported Python version") |
@@ -1,140 +0,0 @@ | |||||
# Copyright (c) 2014 Matthew Rocklin | |||||
# | |||||
# All rights reserved. | |||||
# | |||||
# Redistribution and use in source and binary forms, with or without | |||||
# modification, are permitted provided that the following conditions are met: | |||||
# | |||||
# a. Redistributions of source code must retain the above copyright notice, | |||||
# this list of conditions and the following disclaimer. | |||||
# b. Redistributions in binary form must reproduce the above copyright | |||||
# notice, this list of conditions and the following disclaimer in the | |||||
# documentation and/or other materials provided with the distribution. | |||||
# c. Neither the name of multipledispatch nor the names of its contributors | |||||
# may be used to endorse or promote products derived from this software | |||||
# without specific prior written permission. | |||||
# | |||||
# | |||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
# DAMAGE. | |||||
# | |||||
# -------------------------------------------------------------------------------------- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
# -------------------------------------------------------------------------------------- | |||||
from .utils import typename | |||||
class VariadicSignatureType(type): | |||||
# checking if subclass is a subclass of self | |||||
def __subclasscheck__(self, subclass): | |||||
other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) | |||||
return subclass is self or all( | |||||
issubclass(other, self.variadic_type) for other in other_type | |||||
) | |||||
def __eq__(self, other): | |||||
""" | |||||
Return True if other has the same variadic type | |||||
Parameters | |||||
---------- | |||||
other : object (type) | |||||
The object (type) to check | |||||
Returns | |||||
------- | |||||
bool | |||||
Whether or not `other` is equal to `self` | |||||
""" | |||||
return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type) | |||||
def __hash__(self): | |||||
return hash((type(self), frozenset(self.variadic_type))) | |||||
def isvariadic(obj): | |||||
""" | |||||
Check whether the type `obj` is variadic. | |||||
Parameters | |||||
---------- | |||||
obj : type | |||||
The type to check | |||||
Returns | |||||
------- | |||||
bool | |||||
Whether or not `obj` is variadic | |||||
Examples | |||||
-------- | |||||
>>> isvariadic(int) | |||||
False | |||||
>>> isvariadic(Variadic[int]) | |||||
True | |||||
""" | |||||
return isinstance(obj, VariadicSignatureType) | |||||
class VariadicSignatureMeta(type): | |||||
""" | |||||
A metaclass that overrides ``__getitem__`` on the class. This is used to | |||||
generate a new type for Variadic signatures. See the Variadic class for | |||||
examples of how this behaves. | |||||
""" | |||||
def __getitem__(self, variadic_type): | |||||
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): | |||||
raise ValueError( | |||||
"Variadic types must be type or tuple of types" | |||||
" (Variadic[int] or Variadic[(int, float)]" | |||||
) | |||||
if not isinstance(variadic_type, tuple): | |||||
variadic_type = (variadic_type,) | |||||
return VariadicSignatureType( | |||||
"Variadic[%s]" % typename(variadic_type), | |||||
(), | |||||
dict(variadic_type=variadic_type, __slots__=()), | |||||
) | |||||
class Variadic(metaclass=VariadicSignatureMeta): | |||||
""" | |||||
A class whose getitem method can be used to generate a new type | |||||
representing a specific variadic signature. | |||||
Examples | |||||
-------- | |||||
>>> Variadic[int] # any number of int arguments | |||||
<class 'multipledispatch.variadic.Variadic[int]'> | |||||
>>> Variadic[(int, str)] # any number of one of int or str arguments | |||||
<class 'multipledispatch.variadic.Variadic[(int, str)]'> | |||||
>>> issubclass(int, Variadic[int]) | |||||
True | |||||
>>> issubclass(int, Variadic[(int, str)]) | |||||
True | |||||
>>> issubclass(str, Variadic[(int, str)]) | |||||
True | |||||
>>> issubclass(float, Variadic[(int, str)]) | |||||
False | |||||
""" |
@@ -1,136 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import functools | |||||
import numpy as np | |||||
from ..._imperative_rt import CompNode, DeviceTensorND | |||||
from ..._imperative_rt.imperative import ( | |||||
_drop, | |||||
_get_dev_tensor, | |||||
_swap_in, | |||||
_swap_out, | |||||
apply_op, | |||||
delete, | |||||
get_device, | |||||
get_dtype, | |||||
get_shape, | |||||
get_value, | |||||
put, | |||||
) | |||||
from ..._wrap import device as as_device | |||||
from ...ops.builtin import Copy, OpDef, TypeCvt | |||||
from ...ops.special import Const | |||||
from ..core import OpBase, TensorBase, apply | |||||
class RawTensor(TensorBase): | |||||
_init_cb = None | |||||
_del_cb = None | |||||
_handle = None | |||||
def __init__(self, handle=None, isscalar=False): | |||||
self._handle = handle | |||||
self._isscalar = isscalar | |||||
if handle is not None: | |||||
if self._init_cb: | |||||
self._init_cb() | |||||
@property | |||||
def dtype(self): | |||||
return get_dtype(self._handle) | |||||
@property | |||||
def device(self): | |||||
return as_device(get_device(self._handle)) | |||||
@property | |||||
def shape(self): | |||||
if self._isscalar: | |||||
return () | |||||
return get_shape(self._handle) | |||||
def numpy(self): | |||||
ret = get_value(self._handle) | |||||
if self._isscalar: | |||||
ret = ret.squeeze() | |||||
return ret | |||||
def _dev_tensor(self): | |||||
return _get_dev_tensor(self._handle) | |||||
def _drop(self): | |||||
_drop(self._handle) | |||||
def _swap_in(self): | |||||
_swap_in(self._handle) | |||||
def _swap_out(self): | |||||
_swap_out(self._handle) | |||||
def __repr__(self): | |||||
return "{}({}, device='{}')".format( | |||||
type(self).__qualname__, repr(self.numpy()), self.device | |||||
) | |||||
def __del__(self): | |||||
if self._handle is not None: | |||||
if self._del_cb: | |||||
self._del_cb() | |||||
delete(self._handle) | |||||
@apply.register() | |||||
def _(op: OpDef, *args: RawTensor): | |||||
outputs = apply_op(op, tuple(i._handle for i in args)) | |||||
return tuple(map(RawTensor, outputs)) | |||||
@apply.register() | |||||
def _(op: Const, *args: RawTensor): | |||||
dtype = op.dtype | |||||
device = as_device(op.device).to_c() | |||||
return (as_raw_tensor(op.value, dtype=dtype, device=device),) | |||||
@functools.singledispatch | |||||
def as_raw_tensor(obj, dtype=None, device=None): | |||||
obj = np.asarray(obj, dtype=dtype) | |||||
if obj.dtype == np.float64: | |||||
obj = obj.astype(np.float32) | |||||
if obj.dtype == np.int64: | |||||
obj = obj.astype(np.int32) | |||||
return as_raw_tensor(obj, device=device) | |||||
@as_raw_tensor.register(DeviceTensorND) | |||||
def _(data: DeviceTensorND): | |||||
return RawTensor(put(data)) | |||||
@as_raw_tensor.register(np.ndarray) | |||||
def _(array: np.ndarray, dtype=None, device=None): | |||||
device = None if device is None else as_device(device).to_c() | |||||
if 0 in array.strides: | |||||
array = array.squeeze().reshape(array.shape) | |||||
return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0)) | |||||
@as_raw_tensor.register(RawTensor) | |||||
def _(tensor: RawTensor, dtype=None, device=None): | |||||
if dtype is not None: | |||||
dtype = np.dtype(dtype) | |||||
if dtype != tensor.dtype: | |||||
(tensor,) = apply(TypeCvt(dtype=dtype), tensor) | |||||
if device is not None: | |||||
device = as_device(device) | |||||
if device != tensor.device: | |||||
(tensor,) = apply(Copy(comp_node=device.to_c()), tensor) | |||||
return tensor |
@@ -9,14 +9,7 @@ | |||||
from typing import Optional, Tuple | from typing import Optional, Tuple | ||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||||
from ..core.autodiff.grad import ( | |||||
Tracer, | |||||
check_backward_allow_noinput, | |||||
get_grad_managers, | |||||
get_op_has_grad_fn, | |||||
tracer_apply, | |||||
) | |||||
from ..core.autodiff.grad import get_grad_managers | |||||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | ||||
from ..device import get_default_device | from ..device import get_default_device | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
@@ -236,7 +229,7 @@ def remote_recv( | |||||
device = get_default_device() | device = get_default_device() | ||||
# dummy input | # dummy input | ||||
if inp == None: | if inp == None: | ||||
inp = tensor([0], device=device) | |||||
inp = Tensor([0], device=device) | |||||
tracer_set = get_client().check_remote_tracer(key) | tracer_set = get_client().check_remote_tracer(key) | ||||
for grad_manager in get_grad_managers(): | for grad_manager in get_grad_managers(): | ||||
if grad_manager.name in tracer_set: | if grad_manager.name in tracer_set: | ||||
@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): | |||||
outputs = apply(op, inp) | outputs = apply(op, inp) | ||||
for s, x in zip(shapes, outputs): | for s, x in zip(shapes, outputs): | ||||
if not s: | if not s: | ||||
x._isscalar = True | |||||
x.setscalar() | |||||
return outputs | return outputs | ||||
@@ -10,7 +10,7 @@ | |||||
from typing import Optional, Sequence, Tuple, Union | from typing import Optional, Sequence, Tuple, Union | ||||
from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
from ..core._imperative_rt.core2 import Tensor, apply | |||||
from ..core._imperative_rt.core2 import apply | |||||
from ..core._trace_option import use_symbolic_shape | from ..core._trace_option import use_symbolic_shape | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.ops.builtin import BatchNorm | from ..core.ops.builtin import BatchNorm | ||||
@@ -12,10 +12,10 @@ from typing import Dict | |||||
import numpy as np | import numpy as np | ||||
from .. import functional as F | from .. import functional as F | ||||
from ..core._imperative_rt.core2 import apply | |||||
from ..core.autodiff.grad import Function | from ..core.autodiff.grad import Function | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.tensor import megbrain_graph | from ..core.tensor import megbrain_graph | ||||
from ..core.tensor.core import apply | |||||
from ..core.tensor.dtype import _metadata_dict | from ..core.tensor.dtype import _metadata_dict | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
@@ -3,7 +3,7 @@ import sys | |||||
import pytest | import pytest | ||||
from megengine.core._imperative_rt.imperative import sync | |||||
from megengine.core._imperative_rt.core2 import sync | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) | sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) | ||||
@@ -4,7 +4,6 @@ import megengine as mge | |||||
import megengine.autodiff as ad | import megengine.autodiff as ad | ||||
import megengine.optimizer as optimizer | import megengine.optimizer as optimizer | ||||
from megengine import Parameter, tensor | from megengine import Parameter, tensor | ||||
from megengine.core.tensor.raw_tensor import RawTensor | |||||
from megengine.module import Module | from megengine.module import Module | ||||
@@ -13,7 +13,6 @@ import pytest | |||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
from megengine.core.tensor.core import apply | |||||
from megengine.core.tensor.dtype import ( | from megengine.core.tensor.dtype import ( | ||||
_metadata_dict, | _metadata_dict, | ||||
convert_from_qint4, | convert_from_qint4, | ||||
@@ -1,58 +0,0 @@ | |||||
from megengine.core.tensor.multipledispatch import Dispatcher | |||||
def test_register_many(): | |||||
f = Dispatcher("f") | |||||
log = [] | |||||
@f.register() | |||||
def _(x: int): | |||||
log.append("a") | |||||
return log[-1] | |||||
@f.register() | |||||
def _(x: int): | |||||
log.append("b") | |||||
return log[-1] | |||||
assert f(0) == "b" | |||||
assert log == ["b"] | |||||
def test_return_not_implemented(): | |||||
f = Dispatcher("f") | |||||
log = [] | |||||
@f.register() | |||||
def _(x: int): | |||||
log.append("a") | |||||
return log[-1] | |||||
@f.register() | |||||
def _(x: int): | |||||
log.append("b") | |||||
return NotImplemented | |||||
assert f(0) == "a" | |||||
assert log == ["b", "a"] | |||||
def test_super(): | |||||
f = Dispatcher("f") | |||||
log = [] | |||||
@f.register() | |||||
def _(x: int): | |||||
log.append("a") | |||||
return log[-1] | |||||
@f.register() | |||||
def _(x: int): | |||||
log.append("b") | |||||
return f.super(x) | |||||
assert f(0) == "a" | |||||
assert log == ["b", "a"] |