Browse Source

refactor(mge/imperative): fork multipledispatch

GitOrigin-RevId: a7c25a4302
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
5474538f39
21 changed files with 1461 additions and 82 deletions
  1. +38
    -0
      ACKNOWLEDGMENTS
  2. +2
    -2
      imperative/python/megengine/core/autodiff/grad.py
  3. +1
    -1
      imperative/python/megengine/core/ops/builtin/__init__.py
  4. +3
    -70
      imperative/python/megengine/core/tensor/core.py
  5. +3
    -3
      imperative/python/megengine/core/tensor/function.py
  6. +1
    -1
      imperative/python/megengine/core/tensor/megbrain_graph.py
  7. +10
    -0
      imperative/python/megengine/core/tensor/multipledispatch/__init__.py
  8. +121
    -0
      imperative/python/megengine/core/tensor/multipledispatch/conflict.py
  9. +88
    -0
      imperative/python/megengine/core/tensor/multipledispatch/core.py
  10. +401
    -0
      imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py
  11. +177
    -0
      imperative/python/megengine/core/tensor/multipledispatch/utils.py
  12. +95
    -0
      imperative/python/megengine/core/tensor/multipledispatch/variadic.py
  13. +2
    -2
      imperative/python/megengine/core/tensor/raw_tensor/__init__.py
  14. +1
    -1
      imperative/python/megengine/core/tensor/tensor.py
  15. +1
    -1
      imperative/python/megengine/functional/distributed.py
  16. +0
    -1
      imperative/python/requires.txt
  17. +180
    -0
      imperative/python/src/dispatcher.cpp
  18. +5
    -0
      imperative/python/src/dispatcher.h
  19. +4
    -0
      imperative/python/src/module.cpp
  20. +270
    -0
      imperative/python/src/pyext17.h
  21. +58
    -0
      imperative/python/test/unit/test_dispatch.py

+ 38
- 0
ACKNOWLEDGMENTS View File

@@ -210,6 +210,44 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND


*********************************************************************************************************************************
multipledispatch
--------------------------------------------------------------------
Copyright (c) 2014 Matthew Rocklin

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

a. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
b. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
c. Neither the name of multipledispatch nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.


THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
*********************************************************************************************************************************






*********************************************************************************************************************************
Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-party Components therein:
--------------------------------------------------------------------
protobuf


+ 2
- 2
imperative/python/megengine/core/autodiff/grad.py View File

@@ -343,7 +343,7 @@ def default_has_grad_fn(opnode, reached):
return False


@apply.add
@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
args = tuple(i if isinstance(i, Tracer) else None for i in args)
input_requires_grad = list(map(bool, args))
@@ -385,6 +385,6 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
return tuple(outputs)


@apply.add
@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
return None

+ 1
- 1
imperative/python/megengine/core/ops/builtin/__init__.py View File

@@ -19,7 +19,7 @@ from .._internal.helper import PodOpVisitor
OpBase.register(OpDef)

# forward to apply(OpDef, ...)
@apply.add
@apply.register()
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]):
return apply(op.to_c(), *args)



+ 3
- 70
imperative/python/megengine/core/tensor/core.py View File

@@ -13,7 +13,7 @@ import sys
import typing
from abc import ABC

import multipledispatch
from .multipledispatch import Dispatcher


class OpBase(ABC):
@@ -29,84 +29,17 @@ class TensorWrapperBase:
pass


class Dispatcher(multipledispatch.Dispatcher):
def add(self, f, g=None):
if g is None:
super().add(get_signature(f), f)
else:
super().add(f, g)

return f

def __get__(self, instance, owner=None):
if instance is not None:
return self
return functools.partial(self, instance)


if sys.version_info < (3, 6):

def parse_union(ann):
if type(ann) is not typing.UnionMeta:
return
return ann.__union_params__


elif sys.version_info < (3, 7):

def parse_union(ann):
if type(ann) is not typing._Union:
return
return ann.__args__


elif sys.version_info < (3, 8):

def parse_union(ann):
if type(ann) is not typing._GenericAlias:
if type(ann) is not typing.Union:
return
else:
if ann.__origin__ is not typing.Union:
return
return ann.__args__


else:

def parse_union(ann):
if typing.get_origin(ann) is not typing.Union:
return
return typing.get_args(ann)


def get_signature(function, op_type=None):
sig = inspect.signature(function)
types = []
for p in sig.parameters.values():
ann = p.annotation
ann = parse_union(ann) or ann
if p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
types.append(ann)
if p.kind == inspect.Parameter.VAR_POSITIONAL:
types.append([ann])
return tuple(types)


apply = Dispatcher("apply")

OpBase.apply = apply


@apply.add
@apply.register()
def _(op: OpBase, *args: TensorBase):
raise NotImplementedError


@apply.add
@apply.register()
def _(op: OpBase, *args: TensorWrapperBase):
assert args
Wrapper = type(args[0])


+ 3
- 3
imperative/python/megengine/core/tensor/function.py View File

@@ -102,7 +102,7 @@ class Function:
Function.apply = Function.__call__


@apply.add
@apply.register()
def _(op: Function, *args: TensorWrapperBase):
assert args
Wrapper = type(args[0])
@@ -148,11 +148,11 @@ def _(op: Function, *args: TensorWrapperBase):
return tuple(map(Wrapper, outputs))


@apply.add
@apply.register()
def _(op: Function, *args: Tensor):
raise NotImplementedError


@apply.add
@apply.register()
def _(op: Function, *args: RawTensor):
raise NotImplementedError

+ 1
- 1
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -111,7 +111,7 @@ def _unwrap(x):
return x._node


@apply.add
@apply.register()
def _(op: OpDef, *args: VarNode):
outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs)


+ 10
- 0
imperative/python/megengine/core/tensor/multipledispatch/__init__.py View File

@@ -0,0 +1,10 @@
# This directory is a fork of multipledispatch.
#
# Repo: https://github.com/mrocklin/multipledispatch
# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834
# Authors: Matthew Rocklin et al.
#
# Refer to ACKNOWLEDGMENTS for copyright and license information

from .core import dispatch
from .dispatcher import Dispatcher

+ 121
- 0
imperative/python/megengine/core/tensor/multipledispatch/conflict.py View File

@@ -0,0 +1,121 @@
from .utils import _toposort, groupby
from .variadic import isvariadic


class AmbiguityWarning(Warning):
    """Warning category used to report ambiguous dispatch signatures."""


def supercedes(a, b):
    """Tell whether signature ``a`` is consistent with and at least as specific as ``b``.

    ``a`` and ``b`` are sequences of types; either may end with a variadic
    entry (as recognised by ``isvariadic``).
    """
    len_a, len_b = len(a), len(b)

    if len_a < len_b:
        # A shorter ``a`` can only win when it is empty and ``b`` is a lone
        # variadic entry.
        return not a and len_b == 1 and isvariadic(b[-1])

    if len_a == len_b:
        return all(map(issubclass, a, b))

    # From here on ``a`` is strictly longer than ``b``.
    i = j = 0
    while i < len_a and j < len_b:
        left, right = a[i], b[j]

        if isvariadic(left):
            # A variadic entry must be the final element of its signature.
            assert i == len_a - 1
            return j == len_b - 1 and issubclass(left, right)

        right_is_variadic = isvariadic(right)
        if right_is_variadic:
            assert j == len_b - 1

        if not issubclass(left, right):
            return False

        i += 1
        if not right_is_variadic:
            # A variadic ``right`` keeps absorbing entries of ``a``.
            j += 1

    return j == len_b - 1 and i == len_a


def consistent(a, b):
    """Tell whether some argument list could satisfy both ``a`` and ``b``."""
    # Empty signatures are only compatible with another empty signature or
    # with one that starts with a variadic entry.
    if not a:
        return not b or isvariadic(b[0])
    if not b:
        return not a or isvariadic(a[0])

    if len(a) == len(b):
        # Same arity: each position must be related by subclassing one way
        # or the other.
        return all(issubclass(x, y) or issubclass(y, x) for x, y in zip(a, b))

    i = j = 0
    while i < len(a) and j < len(b):
        left, right = a[i], b[j]
        if not (issubclass(right, left) or issubclass(left, right)):
            return False
        if isvariadic(left):
            j += 1
        elif isvariadic(right):
            i += 1
        else:
            i += 1
            j += 1

    # Different arities can only be reconciled by a trailing variadic entry
    # that swallowed the remainder of the other signature.
    return isvariadic(left) and j == len(b) or isvariadic(right) and i == len(a)


def ambiguous(a, b):
    """Tell whether ``a`` and ``b`` overlap while neither supercedes the other."""
    overlap = consistent(a, b)
    if not overlap:
        return overlap
    return not (supercedes(a, b) or supercedes(b, a))


def ambiguities(signatures):
    """Collect every pair of signatures that is ambiguous and left unresolved.

    A pair is skipped when a third registered signature supercedes both of
    its members. Comparing hashes ensures each unordered pair is reported at
    most once.
    """
    sigs = [tuple(sig) for sig in signatures]
    found = set()
    for a in sigs:
        for b in sigs:
            if not hash(a) < hash(b):
                continue
            if not ambiguous(a, b):
                continue
            if any(supercedes(c, a) and supercedes(c, b) for c in sigs):
                continue
            found.add((a, b))
    return found


def super_signature(signatures):
    """Build a signature that would disambiguate the given ones.

    For every argument position, picks the type whose MRO is the longest
    among the candidates at that position. All signatures must have the same
    length.
    """
    width = len(signatures[0])
    assert all(len(sig) == width for sig in signatures)

    chosen = []
    for pos in range(width):
        deepest_mro = max((type.mro(sig[pos]) for sig in signatures), key=len)
        chosen.append(deepest_mro[0])
    return chosen


def edge(a, b, tie_breaker=hash):
    """Tell whether signature ``a`` should be tried before ``b``.

    ``a`` goes first when it supercedes ``b``. If the two supercede each
    other, ``tie_breaker`` (``hash`` by default) decides.
    """
    a_over_b = supercedes(a, b)
    if not a_over_b:
        return a_over_b
    return not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)


def ordering(signatures):
    """Order signatures for checking, first to last.

    The result is a topological sort of the graph whose edges are given by
    ``edge``.
    """
    sigs = [tuple(sig) for sig in signatures]
    pairs = [(a, b) for a in sigs for b in sigs if edge(a, b)]

    # Group the successor pairs by their source signature.
    graph = groupby(lambda pair: pair[0], pairs)
    for sig in sigs:
        # Signatures without outgoing edges still need a node in the graph.
        if sig not in graph:
            graph[sig] = []

    graph = {src: [dst for _, dst in out] for src, out in graph.items()}
    return _toposort(graph)

+ 88
- 0
imperative/python/megengine/core/tensor/multipledispatch/core.py View File

@@ -0,0 +1,88 @@
import inspect
import sys

from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn

global_namespace = dict()


def dispatch(*types, **kwargs):
    """Decorator factory that registers a function for the given argument types.

    Implementations are collected by function name into a dispatcher. Plain
    functions share the dispatcher stored under their name in ``namespace``
    (keyword argument; defaults to the module-level ``global_namespace``).
    Functions recognised by ``ismethod`` reuse the dispatcher already bound
    to the same name in the caller's local scope (the class body being
    defined), or get a new ``MethodDispatcher``.

    Example::

        @dispatch(int)
        def f(x):
            return x + 1

        @dispatch(float)
        def f(x):
            return x - 1

    The decorator returns the dispatcher, not the decorated function.
    """
    namespace = kwargs.get("namespace", global_namespace)
    signature = tuple(types)

    def _df(func):
        name = func.__name__

        if not ismethod(func):
            if name not in namespace:
                namespace[name] = Dispatcher(name)
            target = namespace[name]
        else:
            # f_back is the frame that applied the decorator; its locals
            # hold any dispatcher previously created for this name.
            target = inspect.currentframe().f_back.f_locals.get(
                name, MethodDispatcher(name),
            )

        target.add(signature, func)
        return target

    return _df


def ismethod(func):
    """Guess whether ``func`` is a method, judged by a ``self`` parameter.

    This must work while the class body is still executing, when methods are
    still ordinary functions.
    """
    if hasattr(inspect, "signature"):
        return "self" in inspect.signature(func).parameters

    # Fallback for interpreters that lack inspect.signature.
    if sys.version_info.major < 3:
        spec = inspect.getargspec(func)
    else:
        spec = inspect.getfullargspec(func)
    return spec and spec.args and spec.args[0] == "self"

+ 401
- 0
imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py View File

@@ -0,0 +1,401 @@
import copy
import inspect
import itertools as itl
from warnings import warn

from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher
from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature
from .utils import expand_tuples, parse_union
from .variadic import Variadic, isvariadic


def ambiguity_warn(dispatcher, ambiguities):
    """Emit an ``AmbiguityWarning`` describing the ambiguous signature pairs.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher on which the ambiguity was found; its ``name`` is used in
        the message.
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``.
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)


def variadic_signature_matches_iter(types, full_signature):
    """Yield one boolean per matching step of ``types`` against a variadic signature.

    Each type in ``types`` is compared with the current signature entry and
    the result is yielded. A non-variadic entry is consumed after one
    comparison; a variadic entry stays current. After all types are
    processed a final value is yielded: ``True`` when no signature entries
    are left over, ``False`` otherwise.

    ``full_signature`` must be non-empty and is expected to end with a
    variadic entry; the advancing ``next`` calls are unguarded.
    """
    remaining = iter(full_signature)
    current = next(remaining)

    for arg_type in types:
        yield issubclass(arg_type, current)
        if isvariadic(current):
            # keep matching further arguments against the variadic entry
            continue
        current = next(remaining)

    try:
        current = next(remaining)
    except StopIteration:
        assert isvariadic(current)
        yield True
    else:
        # Entries are left over, so the arguments did not cover the
        # whole signature.
        yield False


def variadic_signature_matches(types, full_signature):
    """Return True when every step of the variadic match succeeds.

    ``full_signature`` must be non-empty.
    """
    assert full_signature
    for step_ok in variadic_signature_matches_iter(types, full_signature):
        if not step_ok:
            return False
    return True


def get_func_signature(function):
    """Derive a dispatch signature from the parameter annotations of ``function``.

    Positional parameters contribute their annotation (a ``typing.Union`` is
    replaced by the result of ``parse_union``). A ``*args`` parameter
    contributes a one-element list holding its annotation. Other parameter
    kinds are ignored.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    collected = []
    for param in inspect.signature(function).parameters.values():
        annotation = parse_union(param.annotation) or param.annotation
        if param.kind in positional_kinds:
            collected.append(annotation)
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            collected.append([annotation])
    return tuple(collected)


class Frame:
    """Plain record with four fixed attributes: args, types, mro, mro_offset."""

    __slots__ = ("args", "types", "mro", "mro_offset")


class Dispatcher(CDispatcher):
    """ Dispatch calls to implementations selected by argument types

    Implementations are stored in ``funcs``, a dict mapping a type signature
    (tuple of types, optionally ending in a ``Variadic`` type) to the list of
    functions registered for it. Use ``register`` or ``add`` to add
    implementations.

    >>> f = Dispatcher('f')
    >>> @f.register(int)
    ... def inc(x):
    ...     return x + 1

    >>> @f.register(float)
    ... def dec(x):
    ...     return x - 1
    """

    __slots__ = "__name__", "name", "funcs", "_ordering", "doc"

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        self.funcs = {}
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Decorator form of ``add``

        ``types`` is the signature to register under; when empty, ``add``
        derives it from the function's annotations. Keyword arguments are
        forwarded to ``add``. The decorated function is returned unchanged,
        so decorators can be stacked.

        >>> f = Dispatcher('f')
        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]
        """

        def _df(func):
            self.add(types, func, **kwargs)
            return func

        return _df

    def add(self, signature, func):
        """ Register ``func`` under ``signature``

        * An empty ``signature`` is replaced by the one derived from the
          annotations of ``func``.
        * A tuple inside ``signature`` denotes a union; ``func`` is registered
          once for every combination.
        * A one-element list as the last entry denotes a variadic tail.

        Raises ``TypeError`` for malformed signatures and ``ValueError`` when
        ``func`` is already registered under the same signature.
        """
        # Handle annotations
        if not signature:
            signature = get_func_signature(func)

        # Handle union types
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func)
            return

        new_signature = []

        for index, typ in enumerate(signature, start=1):
            if not isinstance(typ, (type, list)):
                str_sig = ", ".join(
                    c.__name__ if isinstance(c, type) else str(c) for c in signature
                )
                raise TypeError(
                    "Tried to dispatch on non-type: %s\n"
                    "In signature: <%s>\n"
                    "In function: %s" % (typ, str_sig, self.name)
                )

            # handle variadic signatures
            if isinstance(typ, list):
                if index != len(signature):
                    raise TypeError("Variadic signature must be the last element")

                if len(typ) != 1:
                    raise TypeError(
                        "Variadic signature must contain exactly one element. "
                        "To use a variadic union type place the desired types "
                        "inside of a tuple, e.g., [(int, str)]"
                    )
                new_signature.append(Variadic[typ[0]])
            else:
                new_signature.append(typ)

        handlers = self.funcs.setdefault(tuple(new_signature), [])
        if any(existing is func for existing in handlers):
            raise ValueError("already registered")
        handlers.append(func)
        self.enable(func)
        self.clear_cache()

        # The cached ordering no longer reflects ``funcs``; drop it so that
        # the ``ordering`` property recomputes it on next access.
        try:
            del self._ordering
        except AttributeError:
            pass

    @property
    def ordering(self):
        """ Signatures in the order they are checked (computed lazily) """
        try:
            return self._ordering
        except AttributeError:
            return self.reorder()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute the signature ordering and report ambiguities

        ``on_ambiguity`` is called with this dispatcher and the set of
        ambiguous signature pairs when that set is non-empty.
        """
        self._ordering = od = ordering(self.funcs)
        amb = ambiguities(self.funcs)
        if amb:
            on_ambiguity(self, amb)
        return od

    def __str__(self):
        return "<dispatched %s>" % self.name

    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine the implementation for this type signature

        An exact signature match returns the most recently registered
        function for it; otherwise the first function produced by
        ``dispatch_iter`` is returned, or ``None`` when nothing matches.
        """
        if types in self.funcs:
            return self.funcs[types][-1]

        return next(self.dispatch_iter(*types), None)

    def dispatch_iter(self, *types):
        """ Yield every matching implementation, following ``ordering``

        Within one signature the functions are yielded newest first.
        """
        n = len(types)
        for signature in self.ordering:
            same_arity_match = len(signature) == n and all(
                map(issubclass, types, signature)
            )
            if same_arity_match or (
                len(signature)
                and isvariadic(signature[-1])
                and variadic_signature_matches(types, signature)
            ):
                yield from self.funcs[signature][::-1]

    def __getstate__(self):
        return {"name": self.name, "funcs": self.funcs, "doc": self.doc}

    def __setstate__(self, d):
        # Restore every slot that __init__ sets; "doc" may be missing from
        # state produced by earlier versions of __getstate__.
        self.name = self.__name__ = d["name"]
        self.funcs = d["funcs"]
        self.doc = d.get("doc")
        self._ordering = ordering(self.funcs)
        # Mirror what add() does for each handler, so that restored handlers
        # are in the same state as freshly added ones.
        for handlers in self.funcs.values():
            for func in handlers:
                self.enable(func)
        self.clear_cache()

    @property
    def __doc__(self):
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        for sig in self.ordering[::-1]:
            funcs = self.funcs[sig]
            s = "Inputs: <%s>\n" % str_signature(sig)
            sep = "-" * len(s) + "\n"
            for i, func in enumerate(funcs):
                s += sep
                if len(funcs) > 1:
                    s += "[Handler %d]\n\n" % (i + 1)
                if i:
                    s += "\n\n"
                if func.__doc__:
                    s += func.__doc__.strip()
                else:
                    s += repr(func) + "\n"
            docs.append(s)

        return "\n\n".join(docs)

    def _help(self, *args):
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))


def source(func):
    """Return the source of ``func`` preceded by a ``File: <path>`` header."""
    header = "File: %s\n\n" % inspect.getsourcefile(func)
    return header + inspect.getsource(func)


class MethodDispatcher(Dispatcher):
    """ Dispatcher used as a method: dispatches on the arguments after ``self``

    Acts as a descriptor; attribute access records the instance and owner
    class on the dispatcher itself before the call is made.

    See Also:
        Dispatcher
    """

    __slots__ = ("obj", "cls")

    @classmethod
    def get_func_params(cls, func):
        """ Parameters of ``func`` with the first one skipped """
        if not hasattr(inspect, "signature"):
            return None
        parameters = inspect.signature(func).parameters.values()
        return itl.islice(parameters, 1, None)

    def __get__(self, instance, owner):
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        implementation = self.dispatch(*arg_types)
        if implementation:
            # the instance recorded by __get__ is passed as ``self``
            return implementation(self.obj, *args, **kwargs)
        raise NotImplementedError(
            "Could not find signature for %s: <%s>"
            % (self.name, str_signature(arg_types))
        )


def str_signature(sig):
    """ Render a type signature as comma-separated type names

    >>> str_signature((int, float))
    'int, float'
    """
    names = [entry.__name__ for entry in sig]
    return ", ".join(names)


def warning_text(name, amb):
    """ Compose the message used for ambiguity warnings

    Lists every ambiguous pair in ``amb`` and then suggests, for each pair,
    a ``@dispatch`` signature (from ``super_signature``) that would resolve it.
    """
    parts = [
        "\nAmbiguities exist in dispatched function %s\n\n" % (name),
        "The following signatures may result in ambiguous behavior:\n",
    ]
    for pair in amb:
        rendered = ", ".join("[" + str_signature(s) + "]" for s in pair)
        parts.append("\t" + rendered + "\n")

    parts.append("\n\nConsider making the following additions:\n\n")
    suggestions = [
        "@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name
        for s in amb
    ]
    parts.append("\n\n".join(suggestions))
    return "".join(parts)

+ 177
- 0
imperative/python/megengine/core/tensor/multipledispatch/utils.py View File

@@ -0,0 +1,177 @@
import sys
import typing
from collections import OrderedDict


def raises(err, lamda):
    """Call ``lamda()`` and report whether it raised ``err``.

    Exceptions that are not instances of ``err`` propagate to the caller.
    """
    try:
        lamda()
    except err:
        return True
    else:
        return False


def expand_tuples(L):
    """Expand tuple entries of ``L`` into every combination of their items.

    >>> expand_tuples([1, (2, 3)])
    [(1, 2), (1, 3)]

    >>> expand_tuples([1, 2])
    [(1, 2)]
    """
    combos = [()]
    # Build from the right so that, as in a head-first recursion, the
    # alternatives of earlier entries vary fastest in the result.
    for entry in reversed(L):
        choices = entry if isinstance(entry, tuple) else (entry,)
        combos = [(choice,) + tail for tail in combos for choice in choices]
    return combos


# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)

inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges

>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]

Closely follows the wikipedia page [2]

[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []

while S:
n, _ = S.popitem()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S[m] = None
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L


def reverse_dict(d):
    """Invert a dependency dict: map each value to the tuple of keys listing it.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: the order of the output follows the iteration order of the
        input dict.
    """
    inverted = OrderedDict()
    for key, values in ((k, d[k]) for k in d):
        for value in values:
            inverted[value] = inverted.get(value, tuple()) + (key,)
    return inverted


# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
    """ Group the items of ``seq`` by the key ``func(item)``

    Returns an ``OrderedDict`` mapping each key to the list of its items;
    keys appear in order of first occurrence.

    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
    """
    groups = OrderedDict()
    for element in seq:
        groups.setdefault(func(element), list()).append(element)
    return groups


def typename(type):
    """Get the name of `type`.

    Parameters
    ----------
    type : Union[Type, Tuple[Type]]

    Returns
    -------
    str
        ``type.__name__`` when available; for a sequence of types, the name
        of its single member, or the parenthesised, comma-separated names.

    Examples
    --------
    >>> typename(int)
    'int'
    >>> typename((int, float))
    '(int, float)'
    """
    if hasattr(type, "__name__"):
        return type.__name__
    if len(type) == 1:
        (only,) = type
        return typename(only)
    names = [typename(member) for member in type]
    return "(%s)" % ", ".join(names)


# parse typing.Union
# ``parse_union(ann)`` returns the member types of a ``typing.Union``
# annotation, or None for any other annotation. The way to recognise a
# Union depends on the running Python version.
if sys.version_info >= (3, 8):

    def parse_union(ann):
        if typing.get_origin(ann) is typing.Union:
            return typing.get_args(ann)
        return None


elif sys.version_info >= (3, 7):

    def parse_union(ann):
        kind = type(ann)
        if kind is typing._GenericAlias:
            if ann.__origin__ is not typing.Union:
                return None
        elif kind is not typing.Union:
            return None
        return ann.__args__


elif sys.version_info >= (3, 6):

    def parse_union(ann):
        if type(ann) is typing._Union:
            return ann.__args__
        return None


else:

    def parse_union(ann):
        if type(ann) is typing.UnionMeta:
            return ann.__union_params__
        return None

+ 95
- 0
imperative/python/megengine/core/tensor/multipledispatch/variadic.py View File

@@ -0,0 +1,95 @@
from .utils import typename


class VariadicSignatureType(type):
    """Metaclass of the types produced by ``Variadic[...]``.

    Instances carry a ``variadic_type`` tuple holding the accepted types.
    """

    def __subclasscheck__(self, subclass):
        # A variadic candidate is compared member by member; any other
        # type is compared directly.
        if isinstance(subclass, VariadicSignatureType):
            candidates = subclass.variadic_type
        else:
            candidates = (subclass,)
        if subclass is self:
            return True
        return all(issubclass(c, self.variadic_type) for c in candidates)

    def __eq__(self, other):
        """
        Return True if other has the same variadic type

        Parameters
        ----------
        other : object (type)
            The object (type) to check

        Returns
        -------
        bool
            Whether `other` is variadic with the same set of accepted types
        """
        if not isinstance(other, VariadicSignatureType):
            return False
        return set(self.variadic_type) == set(other.variadic_type)

    def __hash__(self):
        # consistent with __eq__: independent of the order of the types
        accepted = frozenset(self.variadic_type)
        return hash((type(self), accepted))


def isvariadic(obj):
    """Check whether the type `obj` is variadic.

    Parameters
    ----------
    obj : type
        The type to check

    Returns
    -------
    bool
        True when `obj` is an instance of ``VariadicSignatureType``

    Examples
    --------
    >>> isvariadic(int)
    False
    >>> isvariadic(Variadic[int])
    True
    """
    is_variadic_type = isinstance(obj, VariadicSignatureType)
    return is_variadic_type


class VariadicSignatureMeta(type):
    """A metaclass that overrides ``__getitem__`` on the class. This is used to
    generate a new type for Variadic signatures. See the Variadic class for
    examples of how this behaves.
    """

    def __getitem__(self, variadic_type):
        """Create the variadic type for ``variadic_type``.

        ``variadic_type`` is a type or a tuple of types; a single type is
        wrapped into a one-element tuple.

        Raises
        ------
        ValueError
            If ``variadic_type`` is neither a type nor a tuple.
        """
        # The check used to be OR-ed with ``type(variadic_type)``, which is
        # always truthy, so invalid input was never rejected.
        if not isinstance(variadic_type, (type, tuple)):
            raise ValueError(
                "Variadic types must be type or tuple of types"
                " (Variadic[int] or Variadic[(int, float)])"
            )

        if not isinstance(variadic_type, tuple):
            variadic_type = (variadic_type,)
        return VariadicSignatureType(
            "Variadic[%s]" % typename(variadic_type),
            (),
            dict(variadic_type=variadic_type, __slots__=()),
        )


class Variadic(metaclass=VariadicSignatureMeta):
    """Factory for variadic signature types, used through subscription.

    ``Variadic[T]`` stands for any number of arguments of type ``T``;
    ``Variadic[(T1, T2)]`` for any number of arguments each being one of
    ``T1`` or ``T2``.

    Examples
    --------
    >>> issubclass(int, Variadic[int])
    True
    >>> issubclass(int, Variadic[(int, str)])
    True
    >>> issubclass(str, Variadic[(int, str)])
    True
    >>> issubclass(float, Variadic[(int, str)])
    False
    """

+ 2
- 2
imperative/python/megengine/core/tensor/raw_tensor/__init__.py View File

@@ -66,13 +66,13 @@ class RawTensor(TensorBase):
delete(self._handle)


@apply.add
@apply.register()
def _(op: OpDef, *args: RawTensor):
outputs = apply_op(op, tuple(i._handle for i in args))
return tuple(map(RawTensor, outputs))


@apply.add
@apply.register()
def _(op: Const, *args: RawTensor):
dtype = op.dtype
device = as_device(op.device).to_c()


+ 1
- 1
imperative/python/megengine/core/tensor/tensor.py View File

@@ -79,7 +79,7 @@ def get_context():
return _context


@apply.add
@apply.register()
def tensor_apply(op: OpBase, *args: Tensor):
data = tuple(i._data if isinstance(i, Tensor) else i for i in args)
# type(Tensor._data) is RawTensor


+ 1
- 1
imperative/python/megengine/functional/distributed.py View File

@@ -46,7 +46,7 @@ __all__ = [
]


@apply.add
@apply.register()
def _(op: RemoteSend, *args: Tensor):
ret = tensor_apply(op, *args)



+ 0
- 1
imperative/python/requires.txt View File

@@ -1,5 +1,4 @@
numpy>=1.18
multipledispatch==0.6.0
opencv-python
pyarrow
requests


+ 180
- 0
imperative/python/src/dispatcher.cpp View File

@@ -0,0 +1,180 @@
#include "./dispatcher.h"
#include "./pyext17.h"
#include "megbrain/utils/hash.h"
#include "megbrain/utils/small_vector.h"

#include <unordered_map>
#include <structmember.h>

namespace py = pybind11;
namespace pyx = pyext17;

namespace {

// One registered implementation together with its enabled/disabled switch.
struct Handler {
    // Default member initializers keep a default-constructed Handler in a
    // defined state: no callable, and never selected for a call.
    PyObject* func = nullptr; // borrowed
    bool enabled = false;

    Handler() = default;
    Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {}
};

using FastSig = mgb::SmallVector<void*, 8>;
using MRO = std::vector<Handler*>;

// One entry of the dispatcher call stack: the handler list being walked and
// the position reached in it.
struct Frame {
    // Default member initializers keep a default-constructed Frame in a
    // defined state instead of holding indeterminate values.
    MRO* mro = nullptr;
    size_t mro_offset = 0;

    Frame() = default;
    Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {}
};

// Hash functor for FastSig: feeds the raw bytes of the stored pointers to
// mgb::XXHash.
struct FastSigHash {
    size_t operator()(const FastSig& sig) const {
        // NOTE(review): front() is taken even when sig is empty (a call
        // without arguments) -- confirm that SmallVector permits this.
        auto* first = &sig.front();
        const size_t nbytes = sig.size() * sizeof(FastSig::value_type);
        return mgb::XXHash().update(first, nbytes).digest();
    }
};

// Hash functor that hashes a Python handle by the address of the object it
// refers to.
struct ObjectIdHash : std::hash<void*> {
    size_t operator()(const py::handle& h) const {
        const std::hash<void*>& pointer_hash = *this;
        return pointer_hash(h.ptr());
    }
};

/**
 * Multiple-dispatch callable backing the Python-visible `Dispatcher` type.
 *
 * State:
 *  - `cache`    maps an argument-type signature to the MRO resolved for it.
 *  - `stack`    holds one Frame per dispatch currently in progress; `super()`
 *               reads the innermost frame to continue along the same MRO.
 *  - `registry` owns one Handler per callable passed to `enable()`, hashed by
 *               the callable's address (see ObjectIdHash).
 *
 * NOTE(review): frames and cached MROs are referenced by raw pointer, so
 * calling `clear_cache()` while a dispatch is in progress would leave the
 * active frames pointing at destroyed MROs -- confirm callers never do this.
 */
struct Dispatcher {
    std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
    std::vector<Frame> stack;
    std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;

    // Python object that embeds this C++ instance (borrowed handle).
    inline py::handle self() {
        return pyx::wrap<Dispatcher>::pycast(this);
    }

    /**
     * Build the type signature of `args`, look up (or resolve and cache) its
     * MRO, and push a new frame for it onto `stack`.
     *
     * @return true if a frame was pushed; false if resolution failed, in which
     *         case a Python error has been set by `resolve()` and nothing was
     *         pushed.
     */
    bool prepare_call(PyObject*const* args, Py_ssize_t nargs) {
        FastSig sig(nargs);
        for (Py_ssize_t i = 0; i < nargs; ++i) {
            sig[i] = Py_TYPE(args[i]);
        }
        auto it = cache.find(sig);
        if (it == cache.end()) {
            if (auto mro = resolve(sig)) {
                it = cache.emplace(std::move(sig), std::move(mro)).first;
            } else {
                return false;
            }
        }
        stack.emplace_back(it->second.get());
        return true;
    }

    /**
     * Walk the MRO of the innermost frame, starting at its current offset, and
     * invoke `caller(func)` for each enabled handler until one returns
     * something other than `NotImplemented` (including nullptr, i.e. an
     * error). The frame is popped before returning in every case.
     *
     * @param caller  callable taking the handler's function and returning the
     *                result of calling it (nullptr on error).
     * @return the handler's result, or nullptr with NotImplementedError set
     *         when every remaining handler declined.
     */
    template<typename T>
    PyObject* do_call(T&& caller) {
        // A handler may re-enter the dispatcher (nested call or `super()`),
        // which pushes onto `stack` and can reallocate it. Never keep a
        // reference into `stack` across `caller(...)`; address our frame by
        // index instead. Pushes and pops are balanced, so after `caller`
        // returns our frame is again at `depth`.
        const auto depth = stack.size() - 1;
        // The MRO itself is owned by `cache`, not by `stack`, so this
        // reference is unaffected by stack growth.
        auto& mro = *stack[depth].mro;
        for (; stack[depth].mro_offset < mro.size(); ++stack[depth].mro_offset) {
            const auto i = stack[depth].mro_offset;
            if (mro[i]->enabled) {
                auto ret = caller(mro[i]->func);
                if (ret != Py_NotImplemented) {
                    stack.pop_back();
                    return ret;
                }
                // Handler declined: drop the NotImplemented reference and try
                // the next entry.
                Py_DECREF(ret);
            }
        }
        PyErr_SetString(PyExc_NotImplementedError, "mro exhausted");
        stack.pop_back();
        return nullptr;
    }

    /**
     * Resolve the MRO for a type signature by calling the `dispatch_iter`
     * attribute of the wrapping Python object with the signature's types and
     * mapping every yielded function to its registered Handler.
     *
     * @return the resolved MRO, or nullptr with a Python error set if
     *         `dispatch_iter` raised, a C++ runtime_error was thrown, or a
     *         yielded function is not in `registry`.
     */
    std::unique_ptr<MRO> resolve(const FastSig& sig) {
        try {
            py::tuple args(sig.size());
            for (size_t i = 0; i < sig.size(); ++i) {
                args[i] = (PyObject*)sig[i];
            }
            auto mro_iter = self().attr("dispatch_iter")(*args);
            auto ret = std::make_unique<MRO>();
            for (auto i : mro_iter) {
                auto it = registry.find(py::reinterpret_borrow<py::object>(i));
                if (it == registry.end()) {
                    PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function");
                    return nullptr;
                }
                ret->push_back(it->second.get());
            }
            return ret;
        } catch (py::error_already_set& e) {
            // Hand the Python exception back to the interpreter.
            e.restore();
        } catch (const std::runtime_error& e) {
            PyErr_SetString(PyExc_RuntimeError, e.what());
        }
        return nullptr;
    }

public:
    static constexpr auto tp_name = "Dispatcher";

    // Positional-only entry point: dispatch on the types of `args` and call
    // the selected handler with the same arguments.
    PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) {
        if (!prepare_call(args, nargs)) return nullptr;
        return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);});
    }

    // tp_call entry point: dispatch on the types of the positional arguments
    // only; `kwargs` is forwarded to the handler unchanged.
    PyObject* tp_call(PyObject* args, PyObject* kwargs) {
        if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr;
        return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);});
    }

    /**
     * Continue the innermost in-progress dispatch with the entry after the
     * one currently running, passing `args` to it.
     *
     * @return the next handler's result, or nullptr with RuntimeError set if
     *         no dispatch is in progress (or NotImplementedError if the MRO
     *         is exhausted).
     */
    PyObject* super(PyObject*const* args, Py_ssize_t nargs) {
        if (stack.empty()) {
            PyErr_SetString(PyExc_RuntimeError, "super called at top level");
            return nullptr;
        }
        // Copy the current frame before growing the vector so the source of
        // the copy cannot be invalidated by a reallocation.
        Frame next = stack.back();
        ++next.mro_offset;
        stack.push_back(next);
        return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);});
    }

    // Register `func` if it is unknown, otherwise re-enable its handler.
    void enable(PyObject* func) {
        auto obj = py::reinterpret_borrow<py::object>(func);
        auto it = registry.find(obj);
        if (it != registry.end()) {
            it->second->enabled = true;
        } else {
            registry.emplace(std::move(obj), std::make_unique<Handler>(func));
        }
    }

    // Mark the handler of `func` as disabled so `do_call` skips it.
    // Returns None, or nullptr with ValueError set if `func` was never
    // registered.
    PyObject* disable(PyObject* func) {
        auto obj = py::reinterpret_borrow<py::object>(func);
        auto it = registry.find(obj);
        if (it == registry.end()) {
            PyErr_SetString(PyExc_ValueError, "function not registered");
            return nullptr;
        } else {
            it->second->enabled = false;
        }
        Py_RETURN_NONE;
    }

    // Drop every cached signature -> MRO resolution.
    void clear_cache() {
        cache.clear();
    }
};

} // namespace

// Builds the Python type object for Dispatcher, exposing its methods under
// the names "enable", "disable", "clear_cache", "call" and "super", and
// stores the type on module `m` as attribute "Dispatcher".
// Throws py::error_already_set if the type could not be finalized.
void init_dispatcher(py::module m) {
    auto& builder = pyx::wrap<Dispatcher>::type();
    builder.def<&Dispatcher::enable>("enable");
    builder.def<&Dispatcher::disable>("disable");
    builder.def<&Dispatcher::clear_cache>("clear_cache");
    builder.def<&Dispatcher::tp_vectorcall>("call");
    builder.def<&Dispatcher::super>("super");

    PyObject* dispatcher_type = builder.finalize();
    if (dispatcher_type == nullptr) {
        throw py::error_already_set();
    }
    m.attr("Dispatcher") = dispatcher_type;
}

+ 5
- 0
imperative/python/src/dispatcher.h View File

@@ -0,0 +1,5 @@
#pragma once

#include <pybind11/pybind11.h>

// Installs the `Dispatcher` type as an attribute of the given module.
void init_dispatcher(pybind11::module);

+ 4
- 0
imperative/python/src/module.cpp View File

@@ -21,6 +21,8 @@
#include "./graph_rt.h"
#include "./ops.h"

#include "./dispatcher.h"

namespace py = pybind11;

#ifndef MODULE_NAME
@@ -63,4 +65,6 @@ PYBIND11_MODULE(MODULE_NAME, m) {
from .graph import *
)",
py::getattr(m, "__dict__"));

init_dispatcher(submodule(m, "dispatcher"));
}

+ 270
- 0
imperative/python/src/pyext17.h View File

@@ -0,0 +1,270 @@
#pragma once

// Python.h goes first: it may define macros that affect the standard headers.
#include <Python.h>

#include <cstddef>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

namespace pyext17 {

// True when the METH_FASTCALL macro is defined at this point of the
// translation unit, false otherwise.
constexpr bool has_fastcall =
#ifdef METH_FASTCALL
        true;
#else
        false;
#endif

// Functor answering "can a callable of type T be invoked with Args...?".
// Only the *type* of the argument passed to operator() is inspected; the
// callable itself is never called.
template<typename... Args>
struct invocable_with {
    template<typename T>
    constexpr bool operator()(T&&) {
        return std::is_invocable<T, Args...>::value;
    }
};

// Compile-time member detection built on invocable_with: each macro passes a
// generic lambda whose trailing return type names the member in question, so
// the lambda is only invocable with T when that expression is well-formed.
// The lambdas are never called; only their invocability is queried.
//   HAS_MEMBER_TYPE(T, U): true if `typename std::decay_t<T>::U` names a type.
//   HAS_MEMBER(T, m):      true if `&std::decay_t<T>::m` is a valid expression.
#define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {})
#define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {})

// Return-value conversion for results that already are PyObject*: the
// pointer is passed through untouched.
inline PyObject* cvt_retval(PyObject* result) {
    return result;
}

// Evaluates the given expression and returns from the enclosing function:
// if the expression has type void it is evaluated and None is returned
// (Py_RETURN_NONE); otherwise its value is returned through cvt_retval().
// Must be used as the final statement of a function returning PyObject*.
#define CVT_RET_PYOBJ(...) \
if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
__VA_ARGS__; \
Py_RETURN_NONE; \
} else { \
return cvt_retval(__VA_ARGS__); \
}

/**
 * Python object layout and type-object builder for a C++ class T.
 *
 * An instance of wrap<T> is a Python object header followed by raw storage
 * in which a T is constructed by tp_new and destroyed by tp_dealloc.
 * `wrap<T>::type()` returns the (function-local static) TypeBuilder used to
 * register methods and finalize the PyTypeObject.
 */
template <typename T>
struct wrap {
private:
    typedef wrap<T> wrap_t;

public:
    PyObject_HEAD
    std::aligned_storage_t<sizeof(T), alignof(T)> storage;

    // The T living inside this Python object.
    inline T* inst() {
        return reinterpret_cast<T*>(&storage);
    }

    // Inverse of inst(): recover the enclosing Python object from a T* that
    // points at the `storage` member of a wrap<T>.
    inline static PyObject* pycast(T* ptr) {
        return (PyObject*)((char*)ptr - offsetof(wrap_t, storage));
    }

private:
    // method wrapper

    // Calling convention of a wrapped member function, deduced from the
    // arguments it can be invoked with.
    enum struct meth_type {
        noarg,
        varkw,
        fastcall,
        singarg
    };

    // Classifies member function pointer `f` by trying the supported
    // argument lists in order; fails to compile if none matches.
    template<auto f>
    struct detect_meth_type {
        static constexpr meth_type value = []() {
            using F = decltype(f);
            static_assert(std::is_member_function_pointer_v<F>);
            if constexpr (std::is_invocable_v<F, T>) {
                return meth_type::noarg;
            } else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) {
                return meth_type::varkw;
            } else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) {
                return meth_type::fastcall;
            } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
                return meth_type::singarg;
            } else {
                // Always-false, type-dependent assertion: unsupported signature.
                static_assert(!std::is_same_v<F, F>);
            }
        }();
    };

    // Per-convention trampoline: `impl` is the C function handed to Python,
    // `flags` the matching METH_* flags.
    template<meth_type, auto f>
    struct meth {};

    template<auto f>
    struct meth<meth_type::noarg, f> {
        static constexpr int flags = METH_NOARGS;

        static PyObject* impl(PyObject* self, PyObject*) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ((inst->*f)());
        }
    };

    template<auto f>
    struct meth<meth_type::varkw, f> {
        static constexpr int flags = METH_VARARGS | METH_KEYWORDS;

        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ((inst->*f)(args, kwargs));
        }
    };

    template<auto f>
    struct meth<meth_type::fastcall, f> {
        #ifdef METH_FASTCALL
        static constexpr int flags = METH_FASTCALL;

        static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ((inst->*f)(args, nargs));
        }
        #else
        // Without METH_FASTCALL, fall back to METH_VARARGS and pass the
        // tuple's item array and size to the member function.
        static constexpr int flags = METH_VARARGS;

        static PyObject* impl(PyObject* self, PyObject* args) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            auto* arr = &PyTuple_GET_ITEM(args, 0);
            auto size = PyTuple_GET_SIZE(args);
            CVT_RET_PYOBJ((inst->*f)(arr, size));
        }
        #endif
    };

    template<auto f>
    struct meth<meth_type::singarg, f> {
        static constexpr int flags = METH_O;

        static PyObject* impl(PyObject* self, PyObject* obj) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ((inst->*f)(obj));
        }
    };

    // Build the PyMethodDef entry for member function `f`.
    template<auto f>
    static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) {
        using M = meth<detect_meth_type<f>::value, f>;
        return {name, (PyCFunction)M::impl, M::flags, doc};
    }

    // polyfills

    // tp_new slot: T::tp_new if T provides one; otherwise a generated
    // allocator + constructor when T is constructible from (args, kwargs) or
    // default-constructible; otherwise nullptr.
    struct tp_new {
        static constexpr bool provided = HAS_MEMBER(T, tp_new);
        static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
        static constexpr bool noarg = std::is_default_constructible_v<T>;

        template<typename = void>
        static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
            auto* self = type->tp_alloc(type, 0);
            if (!self) {
                // Allocation failed; tp_alloc has already set the Python
                // error. Do not construct into a null object.
                return nullptr;
            }
            auto* ptr = reinterpret_cast<wrap_t*>(self)->inst();
            if constexpr (varkw) {
                new(ptr) T(args, kwargs);
            } else {
                new(ptr) T();
            }
            return self;
        }

        static constexpr newfunc value = []() {
            if constexpr (provided) return T::tp_new;
            else if constexpr (varkw || noarg) return impl<>;
            else return nullptr;
        }();
    };

    // tp_dealloc slot: T::tp_dealloc if provided, otherwise destroy the
    // embedded T and release the object through the type's tp_free.
    struct tp_dealloc {
        static constexpr bool provided = HAS_MEMBER(T, tp_dealloc);

        template<typename = void>
        static void impl(PyObject* self) {
            reinterpret_cast<wrap_t*>(self)->inst()->~T();
            Py_TYPE(self)->tp_free(self);
        }

        static constexpr destructor value = []() {
            if constexpr (provided) return T::tp_dealloc;
            else return impl<>;
        }();
    };

    // tp_call slot: T::tp_call directly when it is callable as a static
    // (self, args, kwargs) function, a trampoline to the member function
    // tp_call(args, kwargs) when T merely has a tp_call member, else nullptr.
    struct tp_call {
        static constexpr bool valid = HAS_MEMBER(T, tp_call);
        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
                [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});

        template<typename = void>
        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
        }

        static constexpr ternaryfunc value = []() {
            if constexpr (static_form) return T::tp_call;
            else if constexpr (valid) return impl<>;
            else return nullptr;
        }();
    };

public:
    /**
     * Owns the PyTypeObject for wrap<T> together with its method table.
     * Methods are added with def<>() and the type is made ready exactly once
     * by finalize(); after that, def<>() throws.
     */
    class TypeBuilder {
        std::vector<PyMethodDef> m_methods;
        PyTypeObject m_type;
        bool m_finalized = false;
        bool m_ready = false;

        void check_finalized() {
            if (m_finalized) {
                throw std::runtime_error("type is already finalized");
            }
        }
    public:
        TypeBuilder(const TypeBuilder&) = delete;
        TypeBuilder& operator=(const TypeBuilder&) = delete;

        TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
            // static_assert(HAS_MEMBER(T, tp_name));
            if constexpr (HAS_MEMBER(T, tp_name)) {
                m_type.tp_name = T::tp_name;
            }
            m_type.tp_dealloc = tp_dealloc::value;
            m_type.tp_call = tp_call::value;
            m_type.tp_basicsize = sizeof(wrap_t);
            m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
            m_type.tp_new = tp_new::value;
        }

        // Direct access to the underlying PyTypeObject for custom slot setup.
        PyTypeObject* operator->() {
            return &m_type;
        }

        // True once PyType_Ready has succeeded for this type.
        bool ready() const {
            return m_ready;
        }

        /**
         * Install the collected method table and call PyType_Ready.
         *
         * The first call does the work; later calls return the same type
         * object. Once called, the method table is frozen (def<>() throws),
         * because tp_methods points into `m_methods` and further growth
         * would invalidate it.
         *
         * @return the type object, or nullptr with a Python error set if
         *         finalization failed (on this or on an earlier call).
         */
        PyObject* finalize() {
            if (!m_finalized) {
                // Mark first so a failed attempt is not retried (which would
                // append a second sentinel) and so def<>() is rejected from
                // here on.
                m_finalized = true;
                if (m_methods.size()) {
                    // Null sentinel terminating the PyMethodDef array.
                    m_methods.push_back({0});
                    if (m_type.tp_methods) {
                        PyErr_SetString(PyExc_SystemError, "tp_method is already set");
                        return nullptr;
                    }
                    m_type.tp_methods = &m_methods[0];
                }
                if (PyType_Ready(&m_type)) {
                    return nullptr;
                }
                m_ready = true;
            }
            if (!m_ready) {
                // An earlier finalize() attempt failed; never hand out a type
                // object that did not pass PyType_Ready.
                PyErr_SetString(PyExc_SystemError, "type failed to finalize");
                return nullptr;
            }
            return (PyObject*)&m_type;
        }

        // Expose member function `f` as a Python method called `name`.
        // Throws std::runtime_error if the type has already been finalized.
        template<auto f>
        TypeBuilder& def(const char* name, const char* doc = nullptr) {
            check_finalized();
            m_methods.push_back(make_meth_def<f>(name, doc));
            return *this;
        }
    };

    // The single TypeBuilder for wrap<T>, created on first use.
    static TypeBuilder& type() {
        static TypeBuilder type_helper;
        return type_helper;
    }
};

} // namespace pyext17

#undef HAS_MEMBER_TYPE
#undef HAS_MEMBER
#undef CVT_RET_PYOBJ

+ 58
- 0
imperative/python/test/unit/test_dispatch.py View File

@@ -0,0 +1,58 @@
from megengine.core.tensor.multipledispatch import Dispatcher


def test_register_many():
    """Two handlers for the same signature: only the later one runs."""
    dispatcher = Dispatcher("f")
    calls = []

    @dispatcher.register()
    def _(x: int):
        calls.append("a")
        return calls[-1]

    @dispatcher.register()
    def _(x: int):
        calls.append("b")
        return calls[-1]

    result = dispatcher(0)
    assert result == "b"
    # The earlier handler must not have been invoked at all.
    assert calls == ["b"]


def test_return_not_implemented():
    """A handler returning NotImplemented falls through to the earlier one."""
    dispatcher = Dispatcher("f")
    calls = []

    @dispatcher.register()
    def _(x: int):
        calls.append("a")
        return calls[-1]

    @dispatcher.register()
    def _(x: int):
        calls.append("b")
        return NotImplemented

    result = dispatcher(0)
    assert result == "a"
    # The later handler is tried first, then the earlier one.
    assert calls == ["b", "a"]


def test_super():
    """A handler calling ``super`` delegates to the earlier handler."""
    dispatcher = Dispatcher("f")
    calls = []

    @dispatcher.register()
    def _(x: int):
        calls.append("a")
        return calls[-1]

    @dispatcher.register()
    def _(x: int):
        calls.append("b")
        return dispatcher.super(x)

    result = dispatcher(0)
    assert result == "a"
    # The later handler runs first and hands over to the earlier one.
    assert calls == ["b", "a"]

Loading…
Cancel
Save