GitOrigin-RevId: a7c25a4302
tags/v1.0.0-rc1
@@ -210,6 +210,44 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |||
********************************************************************************************************************************* | |||
multipledispatch | |||
-------------------------------------------------------------------- | |||
Copyright (c) 2014 Matthew Rocklin | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are met: | |||
a. Redistributions of source code must retain the above copyright notice, | |||
this list of conditions and the following disclaimer. | |||
b. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in the | |||
documentation and/or other materials provided with the distribution. | |||
c. Neither the name of multipledispatch nor the names of its contributors | |||
may be used to endorse or promote products derived from this software | |||
without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
DAMAGE. | |||
********************************************************************************************************************************* | |||
********************************************************************************************************************************* | |||
Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-party Components therein: | |||
-------------------------------------------------------------------- | |||
protobuf | |||
@@ -343,7 +343,7 @@ def default_has_grad_fn(opnode, reached): | |||
return False | |||
@apply.add | |||
@apply.register() | |||
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||
args = tuple(i if isinstance(i, Tracer) else None for i in args) | |||
input_requires_grad = list(map(bool, args)) | |||
@@ -385,6 +385,6 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||
return tuple(outputs) | |||
@apply.add | |||
@apply.register() | |||
def _(op: Const, *_: typing.Optional[Tracer]): | |||
return None |
@@ -19,7 +19,7 @@ from .._internal.helper import PodOpVisitor | |||
OpBase.register(OpDef) | |||
# forward to apply(OpDef, ...) | |||
@apply.add | |||
@apply.register() | |||
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | |||
return apply(op.to_c(), *args) | |||
@@ -13,7 +13,7 @@ import sys | |||
import typing | |||
from abc import ABC | |||
import multipledispatch | |||
from .multipledispatch import Dispatcher | |||
class OpBase(ABC): | |||
@@ -29,84 +29,17 @@ class TensorWrapperBase: | |||
pass | |||
class Dispatcher(multipledispatch.Dispatcher): | |||
def add(self, f, g=None): | |||
if g is None: | |||
super().add(get_signature(f), f) | |||
else: | |||
super().add(f, g) | |||
return f | |||
def __get__(self, instance, owner=None): | |||
if instance is not None: | |||
return self | |||
return functools.partial(self, instance) | |||
if sys.version_info < (3, 6): | |||
def parse_union(ann): | |||
if type(ann) is not typing.UnionMeta: | |||
return | |||
return ann.__union_params__ | |||
elif sys.version_info < (3, 7): | |||
def parse_union(ann): | |||
if type(ann) is not typing._Union: | |||
return | |||
return ann.__args__ | |||
elif sys.version_info < (3, 8): | |||
def parse_union(ann): | |||
if type(ann) is not typing._GenericAlias: | |||
if type(ann) is not typing.Union: | |||
return | |||
else: | |||
if ann.__origin__ is not typing.Union: | |||
return | |||
return ann.__args__ | |||
else: | |||
def parse_union(ann): | |||
if typing.get_origin(ann) is not typing.Union: | |||
return | |||
return typing.get_args(ann) | |||
def get_signature(function, op_type=None): | |||
sig = inspect.signature(function) | |||
types = [] | |||
for p in sig.parameters.values(): | |||
ann = p.annotation | |||
ann = parse_union(ann) or ann | |||
if p.kind in ( | |||
inspect.Parameter.POSITIONAL_ONLY, | |||
inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||
): | |||
types.append(ann) | |||
if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||
types.append([ann]) | |||
return tuple(types) | |||
apply = Dispatcher("apply") | |||
OpBase.apply = apply | |||
@apply.add | |||
@apply.register() | |||
def _(op: OpBase, *args: TensorBase): | |||
raise NotImplementedError | |||
@apply.add | |||
@apply.register() | |||
def _(op: OpBase, *args: TensorWrapperBase): | |||
assert args | |||
Wrapper = type(args[0]) | |||
@@ -102,7 +102,7 @@ class Function: | |||
Function.apply = Function.__call__ | |||
@apply.add | |||
@apply.register() | |||
def _(op: Function, *args: TensorWrapperBase): | |||
assert args | |||
Wrapper = type(args[0]) | |||
@@ -148,11 +148,11 @@ def _(op: Function, *args: TensorWrapperBase): | |||
return tuple(map(Wrapper, outputs)) | |||
@apply.add | |||
@apply.register() | |||
def _(op: Function, *args: Tensor): | |||
raise NotImplementedError | |||
@apply.add | |||
@apply.register() | |||
def _(op: Function, *args: RawTensor): | |||
raise NotImplementedError |
@@ -111,7 +111,7 @@ def _unwrap(x): | |||
return x._node | |||
@apply.add | |||
@apply.register() | |||
def _(op: OpDef, *args: VarNode): | |||
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||
return _wrap(outputs) | |||
@@ -0,0 +1,10 @@ | |||
# This directory is a fork of multipledispatch. | |||
# | |||
# Repo: https://github.com/mrocklin/multipledispatch | |||
# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834 | |||
# Authors: Matthew Rocklin et al. | |||
# | |||
# Refer to ACKNOWLEDGEMENT for copyright and liscense information | |||
from .core import dispatch | |||
from .dispatcher import Dispatcher |
@@ -0,0 +1,121 @@ | |||
from .utils import _toposort, groupby | |||
from .variadic import isvariadic | |||
class AmbiguityWarning(Warning): | |||
pass | |||
def supercedes(a, b): | |||
""" A is consistent and strictly more specific than B """ | |||
if len(a) < len(b): | |||
# only case is if a is empty and b is variadic | |||
return not a and len(b) == 1 and isvariadic(b[-1]) | |||
elif len(a) == len(b): | |||
return all(map(issubclass, a, b)) | |||
else: | |||
# len(a) > len(b) | |||
p1 = 0 | |||
p2 = 0 | |||
while p1 < len(a) and p2 < len(b): | |||
cur_a = a[p1] | |||
cur_b = b[p2] | |||
if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||
if not issubclass(cur_a, cur_b): | |||
return False | |||
p1 += 1 | |||
p2 += 1 | |||
elif isvariadic(cur_a): | |||
assert p1 == len(a) - 1 | |||
return p2 == len(b) - 1 and issubclass(cur_a, cur_b) | |||
elif isvariadic(cur_b): | |||
assert p2 == len(b) - 1 | |||
if not issubclass(cur_a, cur_b): | |||
return False | |||
p1 += 1 | |||
return p2 == len(b) - 1 and p1 == len(a) | |||
def consistent(a, b): | |||
""" It is possible for an argument list to satisfy both A and B """ | |||
# Need to check for empty args | |||
if not a: | |||
return not b or isvariadic(b[0]) | |||
if not b: | |||
return not a or isvariadic(a[0]) | |||
# Non-empty args check for mutual subclasses | |||
if len(a) == len(b): | |||
return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) | |||
else: | |||
p1 = 0 | |||
p2 = 0 | |||
while p1 < len(a) and p2 < len(b): | |||
cur_a = a[p1] | |||
cur_b = b[p2] | |||
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): | |||
return False | |||
if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||
p1 += 1 | |||
p2 += 1 | |||
elif isvariadic(cur_a): | |||
p2 += 1 | |||
elif isvariadic(cur_b): | |||
p1 += 1 | |||
# We only need to check for variadic ends | |||
# Variadic types are guaranteed to be the last element | |||
return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a) | |||
def ambiguous(a, b): | |||
""" A is consistent with B but neither is strictly more specific """ | |||
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) | |||
def ambiguities(signatures): | |||
""" All signature pairs such that A is ambiguous with B """ | |||
signatures = list(map(tuple, signatures)) | |||
return set( | |||
(a, b) | |||
for a in signatures | |||
for b in signatures | |||
if hash(a) < hash(b) | |||
and ambiguous(a, b) | |||
and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) | |||
) | |||
def super_signature(signatures): | |||
""" A signature that would break ambiguities """ | |||
n = len(signatures[0]) | |||
assert all(len(s) == n for s in signatures) | |||
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)] | |||
def edge(a, b, tie_breaker=hash): | |||
""" A should be checked before B | |||
Tie broken by tie_breaker, defaults to ``hash`` | |||
""" | |||
# A either supercedes B and B does not supercede A or if B does then call | |||
# tie_breaker | |||
return supercedes(a, b) and ( | |||
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) | |||
) | |||
def ordering(signatures): | |||
""" A sane ordering of signatures to check, first to last | |||
Topoological sort of edges as given by ``edge`` and ``supercedes`` | |||
""" | |||
signatures = list(map(tuple, signatures)) | |||
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] | |||
edges = groupby(lambda x: x[0], edges) | |||
for s in signatures: | |||
if s not in edges: | |||
edges[s] = [] | |||
edges = dict((k, [b for a, b in v]) for k, v in edges.items()) | |||
return _toposort(edges) |
@@ -0,0 +1,88 @@ | |||
import inspect | |||
import sys | |||
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn | |||
global_namespace = dict() | |||
def dispatch(*types, **kwargs): | |||
""" Dispatch function on the types of the inputs | |||
Supports dispatch on all non-keyword arguments. | |||
Collects implementations based on the function name. Ignores namespaces. | |||
If ambiguous type signatures occur a warning is raised when the function is | |||
defined suggesting the additional method to break the ambiguity. | |||
Examples | |||
-------- | |||
>>> @dispatch(int) | |||
... def f(x): | |||
... return x + 1 | |||
>>> @dispatch(float) | |||
... def f(x): | |||
... return x - 1 | |||
>>> f(3) | |||
4 | |||
>>> f(3.0) | |||
2.0 | |||
Specify an isolated namespace with the namespace keyword argument | |||
>>> my_namespace = dict() | |||
>>> @dispatch(int, namespace=my_namespace) | |||
... def foo(x): | |||
... return x + 1 | |||
Dispatch on instance methods within classes | |||
>>> class MyClass(object): | |||
... @dispatch(list) | |||
... def __init__(self, data): | |||
... self.data = data | |||
... @dispatch(int) | |||
... def __init__(self, datum): | |||
... self.data = [datum] | |||
""" | |||
namespace = kwargs.get("namespace", global_namespace) | |||
types = tuple(types) | |||
def _df(func): | |||
name = func.__name__ | |||
if ismethod(func): | |||
dispatcher = inspect.currentframe().f_back.f_locals.get( | |||
name, MethodDispatcher(name), | |||
) | |||
else: | |||
if name not in namespace: | |||
namespace[name] = Dispatcher(name) | |||
dispatcher = namespace[name] | |||
dispatcher.add(types, func) | |||
return dispatcher | |||
return _df | |||
def ismethod(func): | |||
""" Is func a method? | |||
Note that this has to work as the method is defined but before the class is | |||
defined. At this stage methods look like functions. | |||
""" | |||
if hasattr(inspect, "signature"): | |||
signature = inspect.signature(func) | |||
return signature.parameters.get("self", None) is not None | |||
else: | |||
if sys.version_info.major < 3: | |||
spec = inspect.getargspec(func) | |||
else: | |||
spec = inspect.getfullargspec(func) | |||
return spec and spec.args and spec.args[0] == "self" |
@@ -0,0 +1,401 @@ | |||
import copy | |||
import inspect | |||
import itertools as itl | |||
from warnings import warn | |||
from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher | |||
from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature | |||
from .utils import expand_tuples, parse_union | |||
from .variadic import Variadic, isvariadic | |||
def ambiguity_warn(dispatcher, ambiguities): | |||
""" Raise warning when ambiguity is detected | |||
Parameters | |||
---------- | |||
dispatcher : Dispatcher | |||
The dispatcher on which the ambiguity was detected | |||
ambiguities : set | |||
Set of type signature pairs that are ambiguous within this dispatcher | |||
See Also: | |||
Dispatcher.add | |||
warning_text | |||
""" | |||
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) | |||
def variadic_signature_matches_iter(types, full_signature): | |||
"""Check if a set of input types matches a variadic signature. | |||
Notes | |||
----- | |||
The algorithm is as follows: | |||
Initialize the current signature to the first in the sequence | |||
For each type in `types`: | |||
If the current signature is variadic | |||
If the type matches the signature | |||
yield True | |||
Else | |||
Try to get the next signature | |||
If no signatures are left we can't possibly have a match | |||
so yield False | |||
Else | |||
yield True if the type matches the current signature | |||
Get the next signature | |||
""" | |||
sigiter = iter(full_signature) | |||
sig = next(sigiter) | |||
for typ in types: | |||
matches = issubclass(typ, sig) | |||
yield matches | |||
if not isvariadic(sig): | |||
# we're not matching a variadic argument, so move to the next | |||
# element in the signature | |||
sig = next(sigiter) | |||
else: | |||
try: | |||
sig = next(sigiter) | |||
except StopIteration: | |||
assert isvariadic(sig) | |||
yield True | |||
else: | |||
# We have signature items left over, so all of our arguments | |||
# haven't matched | |||
yield False | |||
def variadic_signature_matches(types, full_signature): | |||
# No arguments always matches a variadic signature | |||
assert full_signature | |||
return all(variadic_signature_matches_iter(types, full_signature)) | |||
def get_func_signature(function): | |||
sig = inspect.signature(function) | |||
types = [] | |||
for p in sig.parameters.values(): | |||
ann = p.annotation | |||
ann = parse_union(ann) or ann | |||
if p.kind in ( | |||
inspect.Parameter.POSITIONAL_ONLY, | |||
inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||
): | |||
types.append(ann) | |||
if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||
types.append([ann]) | |||
return tuple(types) | |||
class Frame: | |||
__slots__ = "args", "types", "mro", "mro_offset" | |||
class Dispatcher(CDispatcher): | |||
""" Dispatch methods based on type signature | |||
Use ``dispatch`` to add implementations | |||
Examples | |||
-------- | |||
>>> from multipledispatch import dispatch | |||
>>> @dispatch(int) | |||
... def f(x): | |||
... return x + 1 | |||
>>> @dispatch(float) | |||
... def f(x): | |||
... return x - 1 | |||
>>> f(3) | |||
4 | |||
>>> f(3.0) | |||
2.0 | |||
""" | |||
__slots__ = "__name__", "name", "funcs", "_ordering", "doc" | |||
def __init__(self, name, doc=None): | |||
self.name = self.__name__ = name | |||
self.funcs = {} | |||
self.doc = doc | |||
def register(self, *types, **kwargs): | |||
""" register dispatcher with new implementation | |||
>>> f = Dispatcher('f') | |||
>>> @f.register(int) | |||
... def inc(x): | |||
... return x + 1 | |||
>>> @f.register(float) | |||
... def dec(x): | |||
... return x - 1 | |||
>>> @f.register(list) | |||
... @f.register(tuple) | |||
... def reverse(x): | |||
... return x[::-1] | |||
>>> f(1) | |||
2 | |||
>>> f(1.0) | |||
0.0 | |||
>>> f([1, 2, 3]) | |||
[3, 2, 1] | |||
""" | |||
def _df(func): | |||
self.add(types, func, **kwargs) | |||
return func | |||
return _df | |||
def add(self, signature, func): | |||
""" Add new types/method pair to dispatcher | |||
>>> D = Dispatcher('add') | |||
>>> D.add((int, int), lambda x, y: x + y) | |||
>>> D.add((float, float), lambda x, y: x + y) | |||
>>> D(1, 2) | |||
3 | |||
>>> D(1, 2.0) | |||
Traceback (most recent call last): | |||
... | |||
NotImplementedError: Could not find signature for add: <int, float> | |||
When ``add`` detects a warning it calls the ``on_ambiguity`` callback | |||
with a dispatcher/itself, and a set of ambiguous type signature pairs | |||
as inputs. See ``ambiguity_warn`` for an example. | |||
""" | |||
# Handle annotations | |||
if not signature: | |||
signature = get_func_signature(func) | |||
# Handle union types | |||
if any(isinstance(typ, tuple) for typ in signature): | |||
for typs in expand_tuples(signature): | |||
self.add(typs, func) | |||
return | |||
new_signature = [] | |||
for index, typ in enumerate(signature, start=1): | |||
if not isinstance(typ, (type, list)): | |||
str_sig = ", ".join( | |||
c.__name__ if isinstance(c, type) else str(c) for c in signature | |||
) | |||
raise TypeError( | |||
"Tried to dispatch on non-type: %s\n" | |||
"In signature: <%s>\n" | |||
"In function: %s" % (typ, str_sig, self.name) | |||
) | |||
# handle variadic signatures | |||
if isinstance(typ, list): | |||
if index != len(signature): | |||
raise TypeError("Variadic signature must be the last element") | |||
if len(typ) != 1: | |||
raise TypeError( | |||
"Variadic signature must contain exactly one element. " | |||
"To use a variadic union type place the desired types " | |||
"inside of a tuple, e.g., [(int, str)]" | |||
) | |||
new_signature.append(Variadic[typ[0]]) | |||
else: | |||
new_signature.append(typ) | |||
l = self.funcs.setdefault(tuple(new_signature), []) | |||
for i in l: | |||
if i is func: | |||
raise ValueError("already registered") | |||
l.append(func) | |||
self.enable(func) | |||
self.clear_cache() | |||
try: | |||
del self._ordering | |||
except AttributeError: | |||
pass | |||
@property | |||
def ordering(self): | |||
try: | |||
return self._ordering | |||
except AttributeError: | |||
return self.reorder() | |||
def reorder(self, on_ambiguity=ambiguity_warn): | |||
self._ordering = od = ordering(self.funcs) | |||
amb = ambiguities(self.funcs) | |||
if amb: | |||
on_ambiguity(self, amb) | |||
return od | |||
def __str__(self): | |||
return "<dispatched %s>" % self.name | |||
__repr__ = __str__ | |||
def dispatch(self, *types): | |||
"""Deterimine appropriate implementation for this type signature | |||
This method is internal. Users should call this object as a function. | |||
Implementation resolution occurs within the ``__call__`` method. | |||
>>> from multipledispatch import dispatch | |||
>>> @dispatch(int) | |||
... def inc(x): | |||
... return x + 1 | |||
>>> implementation = inc.dispatch(int) | |||
>>> implementation(3) | |||
4 | |||
>>> print(inc.dispatch(float)) | |||
None | |||
See Also: | |||
``multipledispatch.conflict`` - module to determine resolution order | |||
""" | |||
if types in self.funcs: | |||
return self.funcs[types][-1] | |||
for f in self.dispatch_iter(*types): | |||
return f | |||
def dispatch_iter(self, *types): | |||
n = len(types) | |||
for signature in self.ordering: | |||
if ( | |||
len(signature) == n | |||
and all(map(issubclass, types, signature)) | |||
or len(signature) | |||
and isvariadic(signature[-1]) | |||
and variadic_signature_matches(types, signature) | |||
): | |||
yield from self.funcs[signature][::-1] | |||
def __getstate__(self): | |||
return {"name": self.name, "funcs": self.funcs} | |||
def __setstate__(self, d): | |||
self.name = d["name"] | |||
self.funcs = d["funcs"] | |||
self._ordering = ordering(self.funcs) | |||
self._cache = dict() | |||
@property | |||
def __doc__(self): | |||
docs = ["Multiply dispatched method: %s" % self.name] | |||
if self.doc: | |||
docs.append(self.doc) | |||
other = [] | |||
for sig in self.ordering[::-1]: | |||
funcs = self.funcs[sig] | |||
s = "Inputs: <%s>\n" % str_signature(sig) | |||
sep = "-" * len(s) + "\n" | |||
for i, func in enumerate(funcs): | |||
s += sep | |||
if len(funcs) > 1: | |||
s += "[Handler %d]\n\n" % (i + 1) | |||
if i: | |||
s += "\n\n" | |||
if func.__doc__: | |||
s += func.__doc__.strip() | |||
else: | |||
s += repr(func) + "\n" | |||
docs.append(s) | |||
return "\n\n".join(docs) | |||
def _help(self, *args): | |||
return self.dispatch(*map(type, args)).__doc__ | |||
def help(self, *args, **kwargs): | |||
""" Print docstring for the function corresponding to inputs """ | |||
print(self._help(*args)) | |||
def _source(self, *args): | |||
func = self.dispatch(*map(type, args)) | |||
if not func: | |||
raise TypeError("No function found") | |||
return source(func) | |||
def source(self, *args, **kwargs): | |||
""" Print source code for the function corresponding to inputs """ | |||
print(self._source(*args)) | |||
def source(func): | |||
s = "File: %s\n\n" % inspect.getsourcefile(func) | |||
s = s + inspect.getsource(func) | |||
return s | |||
class MethodDispatcher(Dispatcher): | |||
""" Dispatch methods based on type signature | |||
See Also: | |||
Dispatcher | |||
""" | |||
__slots__ = ("obj", "cls") | |||
@classmethod | |||
def get_func_params(cls, func): | |||
if hasattr(inspect, "signature"): | |||
sig = inspect.signature(func) | |||
return itl.islice(sig.parameters.values(), 1, None) | |||
def __get__(self, instance, owner): | |||
self.obj = instance | |||
self.cls = owner | |||
return self | |||
def __call__(self, *args, **kwargs): | |||
types = tuple([type(arg) for arg in args]) | |||
func = self.dispatch(*types) | |||
if not func: | |||
raise NotImplementedError( | |||
"Could not find signature for %s: <%s>" | |||
% (self.name, str_signature(types)) | |||
) | |||
return func(self.obj, *args, **kwargs) | |||
def str_signature(sig): | |||
""" String representation of type signature | |||
>>> str_signature((int, float)) | |||
'int, float' | |||
""" | |||
return ", ".join(cls.__name__ for cls in sig) | |||
def warning_text(name, amb): | |||
""" The text for ambiguity warnings """ | |||
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) | |||
text += "The following signatures may result in ambiguous behavior:\n" | |||
for pair in amb: | |||
text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" | |||
text += "\n\nConsider making the following additions:\n\n" | |||
text += "\n\n".join( | |||
[ | |||
"@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name | |||
for s in amb | |||
] | |||
) | |||
return text |
@@ -0,0 +1,177 @@ | |||
import sys | |||
import typing | |||
from collections import OrderedDict | |||
def raises(err, lamda): | |||
try: | |||
lamda() | |||
return False | |||
except err: | |||
return True | |||
def expand_tuples(L): | |||
""" | |||
>>> expand_tuples([1, (2, 3)]) | |||
[(1, 2), (1, 3)] | |||
>>> expand_tuples([1, 2]) | |||
[(1, 2)] | |||
""" | |||
if not L: | |||
return [()] | |||
elif not isinstance(L[0], tuple): | |||
rest = expand_tuples(L[1:]) | |||
return [(L[0],) + t for t in rest] | |||
else: | |||
rest = expand_tuples(L[1:]) | |||
return [(item,) + t for t in rest for item in L[0]] | |||
# Taken from theano/theano/gof/sched.py | |||
# Avoids licensing issues because this was written by Matthew Rocklin | |||
def _toposort(edges): | |||
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices) | |||
inputs: | |||
edges - a dict of the form {a: {b, c}} where b and c depend on a | |||
outputs: | |||
L - an ordered list of nodes that satisfy the dependencies of edges | |||
>>> _toposort({1: (2, 3), 2: (3, )}) | |||
[1, 2, 3] | |||
Closely follows the wikipedia page [2] | |||
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks", | |||
Communications of the ACM | |||
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms | |||
""" | |||
incoming_edges = reverse_dict(edges) | |||
incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) | |||
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) | |||
L = [] | |||
while S: | |||
n, _ = S.popitem() | |||
L.append(n) | |||
for m in edges.get(n, ()): | |||
assert n in incoming_edges[m] | |||
incoming_edges[m].remove(n) | |||
if not incoming_edges[m]: | |||
S[m] = None | |||
if any(incoming_edges.get(v, None) for v in edges): | |||
raise ValueError("Input has cycles") | |||
return L | |||
def reverse_dict(d): | |||
"""Reverses direction of dependence dict | |||
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} | |||
>>> reverse_dict(d) # doctest: +SKIP | |||
{1: ('a',), 2: ('a', 'b'), 3: ('b',)} | |||
:note: dict order are not deterministic. As we iterate on the | |||
input dict, it make the output of this function depend on the | |||
dict order. So this function output order should be considered | |||
as undeterministic. | |||
""" | |||
result = OrderedDict() | |||
for key in d: | |||
for val in d[key]: | |||
result[val] = result.get(val, tuple()) + (key,) | |||
return result | |||
# Taken from toolz | |||
# Avoids licensing issues because this version was authored by Matthew Rocklin | |||
def groupby(func, seq): | |||
""" Group a collection by a key function | |||
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] | |||
>>> groupby(len, names) # doctest: +SKIP | |||
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} | |||
>>> iseven = lambda x: x % 2 == 0 | |||
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP | |||
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]} | |||
See Also: | |||
``countby`` | |||
""" | |||
d = OrderedDict() | |||
for item in seq: | |||
key = func(item) | |||
if key not in d: | |||
d[key] = list() | |||
d[key].append(item) | |||
return d | |||
def typename(type): | |||
"""Get the name of `type`. | |||
Parameters | |||
---------- | |||
type : Union[Type, Tuple[Type]] | |||
Returns | |||
------- | |||
str | |||
The name of `type` or a tuple of the names of the types in `type`. | |||
Examples | |||
-------- | |||
>>> typename(int) | |||
'int' | |||
>>> typename((int, float)) | |||
'(int, float)' | |||
""" | |||
try: | |||
return type.__name__ | |||
except AttributeError: | |||
if len(type) == 1: | |||
return typename(*type) | |||
return "(%s)" % ", ".join(map(typename, type)) | |||
# parse typing.Union | |||
if sys.version_info < (3, 6): | |||
def parse_union(ann): | |||
if type(ann) is not typing.UnionMeta: | |||
return | |||
return ann.__union_params__ | |||
elif sys.version_info < (3, 7): | |||
def parse_union(ann): | |||
if type(ann) is not typing._Union: | |||
return | |||
return ann.__args__ | |||
elif sys.version_info < (3, 8): | |||
def parse_union(ann): | |||
if type(ann) is not typing._GenericAlias: | |||
if type(ann) is not typing.Union: | |||
return | |||
else: | |||
if ann.__origin__ is not typing.Union: | |||
return | |||
return ann.__args__ | |||
else: | |||
def parse_union(ann): | |||
if typing.get_origin(ann) is not typing.Union: | |||
return | |||
return typing.get_args(ann) |
@@ -0,0 +1,95 @@ | |||
from .utils import typename | |||
class VariadicSignatureType(type): | |||
# checking if subclass is a subclass of self | |||
def __subclasscheck__(self, subclass): | |||
other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) | |||
return subclass is self or all( | |||
issubclass(other, self.variadic_type) for other in other_type | |||
) | |||
def __eq__(self, other): | |||
""" | |||
Return True if other has the same variadic type | |||
Parameters | |||
---------- | |||
other : object (type) | |||
The object (type) to check | |||
Returns | |||
------- | |||
bool | |||
Whether or not `other` is equal to `self` | |||
""" | |||
return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type) | |||
def __hash__(self): | |||
return hash((type(self), frozenset(self.variadic_type))) | |||
def isvariadic(obj): | |||
"""Check whether the type `obj` is variadic. | |||
Parameters | |||
---------- | |||
obj : type | |||
The type to check | |||
Returns | |||
------- | |||
bool | |||
Whether or not `obj` is variadic | |||
Examples | |||
-------- | |||
>>> isvariadic(int) | |||
False | |||
>>> isvariadic(Variadic[int]) | |||
True | |||
""" | |||
return isinstance(obj, VariadicSignatureType) | |||
class VariadicSignatureMeta(type): | |||
"""A metaclass that overrides ``__getitem__`` on the class. This is used to | |||
generate a new type for Variadic signatures. See the Variadic class for | |||
examples of how this behaves. | |||
""" | |||
def __getitem__(self, variadic_type): | |||
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): | |||
raise ValueError( | |||
"Variadic types must be type or tuple of types" | |||
" (Variadic[int] or Variadic[(int, float)]" | |||
) | |||
if not isinstance(variadic_type, tuple): | |||
variadic_type = (variadic_type,) | |||
return VariadicSignatureType( | |||
"Variadic[%s]" % typename(variadic_type), | |||
(), | |||
dict(variadic_type=variadic_type, __slots__=()), | |||
) | |||
class Variadic(metaclass=VariadicSignatureMeta): | |||
"""A class whose getitem method can be used to generate a new type | |||
representing a specific variadic signature. | |||
Examples | |||
-------- | |||
>>> Variadic[int] # any number of int arguments | |||
<class 'multipledispatch.variadic.Variadic[int]'> | |||
>>> Variadic[(int, str)] # any number of one of int or str arguments | |||
<class 'multipledispatch.variadic.Variadic[(int, str)]'> | |||
>>> issubclass(int, Variadic[int]) | |||
True | |||
>>> issubclass(int, Variadic[(int, str)]) | |||
True | |||
>>> issubclass(str, Variadic[(int, str)]) | |||
True | |||
>>> issubclass(float, Variadic[(int, str)]) | |||
False | |||
""" |
@@ -66,13 +66,13 @@ class RawTensor(TensorBase): | |||
delete(self._handle) | |||
@apply.add | |||
@apply.register() | |||
def _(op: OpDef, *args: RawTensor): | |||
outputs = apply_op(op, tuple(i._handle for i in args)) | |||
return tuple(map(RawTensor, outputs)) | |||
@apply.add | |||
@apply.register() | |||
def _(op: Const, *args: RawTensor): | |||
dtype = op.dtype | |||
device = as_device(op.device).to_c() | |||
@@ -79,7 +79,7 @@ def get_context(): | |||
return _context | |||
@apply.add | |||
@apply.register() | |||
def tensor_apply(op: OpBase, *args: Tensor): | |||
data = tuple(i._data if isinstance(i, Tensor) else i for i in args) | |||
# type(Tensor._data) is RawTensor | |||
@@ -46,7 +46,7 @@ __all__ = [ | |||
] | |||
@apply.add | |||
@apply.register() | |||
def _(op: RemoteSend, *args: Tensor): | |||
ret = tensor_apply(op, *args) | |||
@@ -1,5 +1,4 @@ | |||
numpy>=1.18 | |||
multipledispatch==0.6.0 | |||
opencv-python | |||
pyarrow | |||
requests | |||
@@ -0,0 +1,180 @@ | |||
#include "./dispatcher.h" | |||
#include "./pyext17.h" | |||
#include "megbrain/utils/hash.h" | |||
#include "megbrain/utils/small_vector.h" | |||
#include <unordered_map> | |||
#include <structmember.h> | |||
namespace py = pybind11; | |||
namespace pyx = pyext17; | |||
namespace { | |||
struct Handler { | |||
PyObject* func; // borrowed | |||
bool enabled; | |||
Handler() = default; | |||
Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {} | |||
}; | |||
using FastSig = mgb::SmallVector<void*, 8>; | |||
using MRO = std::vector<Handler*>; | |||
struct Frame { | |||
MRO* mro; | |||
size_t mro_offset; | |||
Frame() = default; | |||
Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {} | |||
}; | |||
struct FastSigHash { | |||
size_t operator()(const FastSig& sig) const { | |||
auto* ptr = &sig.front(); | |||
return mgb::XXHash() | |||
.update(ptr, sig.size() * sizeof(FastSig::value_type)) | |||
.digest(); | |||
} | |||
}; | |||
struct ObjectIdHash : std::hash<void*> { | |||
size_t operator()(const py::handle& h) const { | |||
return std::hash<void*>::operator()(h.ptr()); | |||
} | |||
}; | |||
struct Dispatcher { | |||
std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache; | |||
std::vector<Frame> stack; | |||
std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry; | |||
inline py::handle self() { | |||
return pyx::wrap<Dispatcher>::pycast(this); | |||
} | |||
bool prepare_call(PyObject*const* args, Py_ssize_t nargs) { | |||
FastSig sig(nargs); | |||
for (Py_ssize_t i = 0; i < nargs; ++i) { | |||
sig[i] = Py_TYPE(args[i]); | |||
} | |||
auto it = cache.find(sig); | |||
if (it == cache.end()) { | |||
if (auto mro = resolve(sig)) { | |||
it = cache.emplace(std::move(sig), std::move(mro)).first; | |||
} else { | |||
return false; | |||
} | |||
} | |||
stack.emplace_back(it->second.get()); | |||
return true; | |||
} | |||
template<typename T> | |||
PyObject* do_call(T&& caller) { | |||
auto& frame = stack.back(); | |||
auto& mro = *frame.mro; | |||
auto& i = frame.mro_offset; | |||
for (; i < mro.size(); ++i) { | |||
if (mro[i]->enabled) { | |||
auto ret = caller(mro[i]->func); | |||
if (ret != Py_NotImplemented) { | |||
stack.pop_back(); | |||
return ret; | |||
} | |||
Py_DECREF(ret); | |||
} | |||
} | |||
PyErr_SetString(PyExc_NotImplementedError, "mro exhausted"); | |||
stack.pop_back(); | |||
return nullptr; | |||
} | |||
std::unique_ptr<MRO> resolve(const FastSig& sig) { | |||
try { | |||
py::tuple args(sig.size()); | |||
for (size_t i = 0; i < sig.size(); ++i) { | |||
args[i] = (PyObject*)sig[i]; | |||
} | |||
auto mro_iter = self().attr("dispatch_iter")(*args); | |||
auto ret = std::make_unique<MRO>(); | |||
for (auto i : mro_iter) { | |||
auto it = registry.find(py::reinterpret_borrow<py::object>(i)); | |||
if (it == registry.end()) { | |||
PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function"); | |||
return nullptr; | |||
} | |||
ret->push_back(it->second.get()); | |||
} | |||
return ret; | |||
} catch (py::error_already_set& e) { | |||
e.restore(); | |||
} catch (std::runtime_error& e) { | |||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
} | |||
return nullptr; | |||
} | |||
public: | |||
static constexpr auto tp_name = "Dispatcher"; | |||
PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) { | |||
if (!prepare_call(args, nargs)) return nullptr; | |||
return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);}); | |||
} | |||
PyObject* tp_call(PyObject* args, PyObject* kwargs) { | |||
if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr; | |||
return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);}); | |||
} | |||
PyObject* super(PyObject*const* args, Py_ssize_t nargs) { | |||
if (stack.empty()) { | |||
PyErr_SetString(PyExc_RuntimeError, "super called at top level"); | |||
return nullptr; | |||
} | |||
stack.emplace_back(stack.back()).mro_offset++; | |||
return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);}); | |||
} | |||
void enable(PyObject* func) { | |||
auto obj = py::reinterpret_borrow<py::object>(func); | |||
auto it = registry.find(obj); | |||
if (it != registry.end()) { | |||
it->second->enabled = true; | |||
} else { | |||
registry.emplace(std::move(obj), std::make_unique<Handler>(func)); | |||
} | |||
} | |||
PyObject* disable(PyObject* func) { | |||
auto obj = py::reinterpret_borrow<py::object>(func); | |||
auto it = registry.find(obj); | |||
if (it == registry.end()) { | |||
PyErr_SetString(PyExc_ValueError, "function not registered"); | |||
return nullptr; | |||
} else { | |||
it->second->enabled = false; | |||
} | |||
Py_RETURN_NONE; | |||
} | |||
void clear_cache() { | |||
cache.clear(); | |||
} | |||
}; | |||
} // namespace | |||
void init_dispatcher(py::module m) { | |||
auto* dispatcher_type = pyx::wrap<Dispatcher>::type() | |||
.def<&Dispatcher::enable>("enable") | |||
.def<&Dispatcher::disable>("disable") | |||
.def<&Dispatcher::clear_cache>("clear_cache") | |||
.def<&Dispatcher::tp_vectorcall>("call") | |||
.def<&Dispatcher::super>("super") | |||
.finalize(); | |||
if (!dispatcher_type) throw py::error_already_set(); | |||
m.attr("Dispatcher") = dispatcher_type; | |||
} |
@@ -0,0 +1,5 @@ | |||
#pragma once | |||
#include <pybind11/pybind11.h> | |||
void init_dispatcher(pybind11::module); |
@@ -21,6 +21,8 @@ | |||
#include "./graph_rt.h" | |||
#include "./ops.h" | |||
#include "./dispatcher.h" | |||
namespace py = pybind11; | |||
#ifndef MODULE_NAME | |||
@@ -63,4 +65,6 @@ PYBIND11_MODULE(MODULE_NAME, m) { | |||
from .graph import * | |||
)", | |||
py::getattr(m, "__dict__")); | |||
init_dispatcher(submodule(m, "dispatcher")); | |||
} |
@@ -0,0 +1,270 @@ | |||
#pragma once | |||
#include <stdexcept> | |||
#include <vector> | |||
#include <utility> | |||
#include <Python.h> | |||
namespace pyext17 { | |||
#ifdef METH_FASTCALL | |||
constexpr bool has_fastcall = true; | |||
#else | |||
constexpr bool has_fastcall = false; | |||
#endif | |||
template<typename... Args> | |||
struct invocable_with { | |||
template<typename T> | |||
constexpr bool operator()(T&& lmb) { | |||
return std::is_invocable_v<T, Args...>; | |||
} | |||
}; | |||
#define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {}) | |||
#define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {}) | |||
inline PyObject* cvt_retval(PyObject* rv) { | |||
return rv; | |||
} | |||
#define CVT_RET_PYOBJ(...) \ | |||
if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \ | |||
__VA_ARGS__; \ | |||
Py_RETURN_NONE; \ | |||
} else { \ | |||
return cvt_retval(__VA_ARGS__); \ | |||
} | |||
template <typename T> | |||
struct wrap { | |||
private: | |||
typedef wrap<T> wrap_t; | |||
public: | |||
PyObject_HEAD | |||
std::aligned_storage_t<sizeof(T), alignof(T)> storage; | |||
inline T* inst() { | |||
return reinterpret_cast<T*>(&storage); | |||
} | |||
inline static PyObject* pycast(T* ptr) { | |||
return (PyObject*)((char*)ptr - offsetof(wrap_t, storage)); | |||
} | |||
private: | |||
// method wrapper | |||
enum struct meth_type { | |||
noarg, | |||
varkw, | |||
fastcall, | |||
singarg | |||
}; | |||
template<auto f> | |||
struct detect_meth_type { | |||
static constexpr meth_type value = []() { | |||
using F = decltype(f); | |||
static_assert(std::is_member_function_pointer_v<F>); | |||
if constexpr (std::is_invocable_v<F, T>) { | |||
return meth_type::noarg; | |||
} else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) { | |||
return meth_type::varkw; | |||
} else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) { | |||
return meth_type::fastcall; | |||
} else if constexpr (std::is_invocable_v<F, T, PyObject*>) { | |||
return meth_type::singarg; | |||
} else { | |||
static_assert(!std::is_same_v<F, F>); | |||
} | |||
}(); | |||
}; | |||
template<meth_type, auto f> | |||
struct meth {}; | |||
template<auto f> | |||
struct meth<meth_type::noarg, f> { | |||
static constexpr int flags = METH_NOARGS; | |||
static PyObject* impl(PyObject* self, PyObject*) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ((inst->*f)()); | |||
} | |||
}; | |||
template<auto f> | |||
struct meth<meth_type::varkw, f> { | |||
static constexpr int flags = METH_VARARGS | METH_KEYWORDS; | |||
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | |||
} | |||
}; | |||
template<auto f> | |||
struct meth<meth_type::fastcall, f> { | |||
#ifdef METH_FASTCALL | |||
static constexpr int flags = METH_FASTCALL; | |||
static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ((inst->*f)(args, nargs)); | |||
} | |||
#else | |||
static constexpr int flags = METH_VARARGS; | |||
static PyObject* impl(PyObject* self, PyObject* args) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
auto* arr = &PyTuple_GET_ITEM(args, 0); | |||
auto size = PyTuple_GET_SIZE(args); | |||
CVT_RET_PYOBJ((inst->*f)(arr, size)); | |||
} | |||
#endif | |||
}; | |||
template<auto f> | |||
struct meth<meth_type::singarg, f> { | |||
static constexpr int flags = METH_O; | |||
static PyObject* impl(PyObject* self, PyObject* obj) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ((inst->*f)(obj)); | |||
} | |||
}; | |||
template<auto f> | |||
static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) { | |||
using M = meth<detect_meth_type<f>::value, f>; | |||
return {name, (PyCFunction)M::impl, M::flags, doc}; | |||
} | |||
// polyfills | |||
struct tp_new { | |||
static constexpr bool provided = HAS_MEMBER(T, tp_new); | |||
static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>; | |||
static constexpr bool noarg = std::is_default_constructible_v<T>; | |||
template<typename = void> | |||
static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | |||
auto* self = type->tp_alloc(type, 0); | |||
auto* ptr = reinterpret_cast<wrap_t*>(self)->inst(); | |||
if constexpr (varkw) { | |||
new(ptr) T(args, kwargs); | |||
} else { | |||
new(ptr) T(); | |||
} | |||
return self; | |||
} | |||
static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new; | |||
else if constexpr (varkw || noarg) return impl<>; | |||
else return nullptr;}(); | |||
}; | |||
struct tp_dealloc { | |||
static constexpr bool provided = HAS_MEMBER(T, tp_dealloc); | |||
template<typename = void> | |||
static void impl(PyObject* self) { | |||
reinterpret_cast<wrap_t*>(self)->inst()->~T(); | |||
Py_TYPE(self)->tp_free(self); | |||
} | |||
static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc; | |||
else return impl<>;}(); | |||
}; | |||
struct tp_call { | |||
static constexpr bool valid = HAS_MEMBER(T, tp_call); | |||
static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}( | |||
[](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {}); | |||
template<typename = void> | |||
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); | |||
} | |||
static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; | |||
else if constexpr (valid) return impl<>; | |||
else return nullptr;}(); | |||
}; | |||
public: | |||
class TypeBuilder { | |||
std::vector<PyMethodDef> m_methods; | |||
PyTypeObject m_type; | |||
bool m_finalized = false; | |||
bool m_ready = false; | |||
void check_finalized() { | |||
if (m_finalized) { | |||
throw std::runtime_error("type is already finalized"); | |||
} | |||
} | |||
public: | |||
TypeBuilder(const TypeBuilder&) = delete; | |||
TypeBuilder& operator=(const TypeBuilder&) = delete; | |||
TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { | |||
// static_assert(HAS_MEMBER(T, tp_name)); | |||
if constexpr (HAS_MEMBER(T, tp_name)) { | |||
m_type.tp_name = T::tp_name; | |||
} | |||
m_type.tp_dealloc = tp_dealloc::value; | |||
m_type.tp_call = tp_call::value; | |||
m_type.tp_basicsize = sizeof(wrap_t); | |||
m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
m_type.tp_new = tp_new::value; | |||
} | |||
PyTypeObject* operator->() { | |||
return &m_type; | |||
} | |||
bool ready() const { | |||
return m_ready; | |||
} | |||
PyObject* finalize() { | |||
if (!m_finalized) { | |||
if (m_methods.size()) { | |||
m_methods.push_back({0}); | |||
if (m_type.tp_methods) { | |||
PyErr_SetString(PyExc_SystemError, "tp_method is already set"); | |||
return nullptr; | |||
} | |||
m_type.tp_methods = &m_methods[0]; | |||
} | |||
if (PyType_Ready(&m_type)) { | |||
return nullptr; | |||
} | |||
m_ready = true; | |||
} | |||
return (PyObject*)&m_type; | |||
} | |||
template<auto f> | |||
TypeBuilder& def(const char* name, const char* doc = nullptr) { | |||
check_finalized(); | |||
m_methods.push_back(make_meth_def<f>(name, doc)); | |||
return *this; | |||
} | |||
}; | |||
static TypeBuilder& type() { | |||
static TypeBuilder type_helper; | |||
return type_helper; | |||
} | |||
}; | |||
} // namespace pyext17 | |||
#undef HAS_MEMBER_TYPE | |||
#undef HAS_MEMBER | |||
#undef CVT_RET_PYOBJ |
@@ -0,0 +1,58 @@ | |||
from megengine.core.tensor.multipledispatch import Dispatcher | |||
def test_register_many(): | |||
f = Dispatcher("f") | |||
log = [] | |||
@f.register() | |||
def _(x: int): | |||
log.append("a") | |||
return log[-1] | |||
@f.register() | |||
def _(x: int): | |||
log.append("b") | |||
return log[-1] | |||
assert f(0) == "b" | |||
assert log == ["b"] | |||
def test_return_not_implemented(): | |||
f = Dispatcher("f") | |||
log = [] | |||
@f.register() | |||
def _(x: int): | |||
log.append("a") | |||
return log[-1] | |||
@f.register() | |||
def _(x: int): | |||
log.append("b") | |||
return NotImplemented | |||
assert f(0) == "a" | |||
assert log == ["b", "a"] | |||
def test_super(): | |||
f = Dispatcher("f") | |||
log = [] | |||
@f.register() | |||
def _(x: int): | |||
log.append("a") | |||
return log[-1] | |||
@f.register() | |||
def _(x: int): | |||
log.append("b") | |||
return f.super(x) | |||
assert f(0) == "a" | |||
assert log == ["b", "a"] |