@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
             name = p + e
             e = self._enums[(p, e)]
             self._write_doc(e.name)
-            self._write("enum %s%s : uint {", p, e.name, indent=1)
+            attribute = "(bit_flags)" if e.combined else ""
+            self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1)
             for idx, member in enumerate(e.members):
                 self._write_doc(member)
-                if e.combined:
-                    self._write("%s=%d,", scramble_enum_member_name(str(member)),
-                                1<<idx)
-                else:
-                    self._write("%s,", scramble_enum_member_name(str(member)))
+                self._write("%s,", scramble_enum_member_name(str(member)))
             self._write("}\n", indent=-1)

     def _write_doc(self, doc):
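
Note: with the FlatBuffers `(bit_flags)` attribute, each member is implicitly assigned the value `1 << its index`, which is why the writer no longer emits explicit `%s=%d` values. For a hypothetical combined enum (names shown unscrambled, for illustration only) the emitted schema looks roughly like:

enum ParamStrategy : uint (bit_flags) {
    HEURISTIC,    // implicitly 1 << 0
    PROFILE,      // implicitly 1 << 1
    REPRODUCIBLE, // implicitly 1 << 2
}
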
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
             return
         self._write_doc(e.name)
         self._used_enum.add(key)
-        self._write("%s:%s%s = %s;", e.name_field, p.name, e.name,
-                    scramble_enum_member_name(str(e.members[e.default])))
+        if e.combined:
+            default = e.compose_combined_enum(e.default)
+        else:
+            default = scramble_enum_member_name(str(e.members[e.default]))
+        self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)

     def _resolve_const(self, v):
         while v in self._cur_const_val:
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
             return
         self._used_enum.add((e.src_class, e.src_name))
         enum_name = e.src_class + e.src_name
-        self._write(
-            "%s:%s = %s;", e.name_field, enum_name,
-            scramble_enum_member_name(str(e.src_enum.members[e.get_default()])))
+        s = e.src_enum
+        if s.combined:
+            default = s.compose_combined_enum(e.get_default())
+        else:
+            default = scramble_enum_member_name(str(s.members[e.get_default()]))
+        self._write("%s:%s = %s;", e.name_field, enum_name, default)

     def _get_fb_default(self, cppdefault):
         if not isinstance(cppdefault, str):
@@ -73,11 +73,21 @@ class member_defs: | |||||
"""define an enum; the result would contain both an enum class def and its | """define an enum; the result would contain both an enum class def and its | ||||
corresponding data field | corresponding data field | ||||
:param default: index of default member value | |||||
:param default: | |||||
for normal enum class: index of default member value | |||||
for bit combined class: tuple of index of default member value | |||||
For example, following representations of the default value for bit | |||||
combined class are all equivalent: | |||||
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...) | |||||
Enum(members=('a', 'b', 'c'), default=(0, 1), ...) | |||||
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...) | |||||
:attr name_field: name of the data field of this enum in the param | :attr name_field: name of the data field of this enum in the param | ||||
struct | struct | ||||
:attr member_alias: list of (member, alias) pairs | |||||
:attr member_alias: | |||||
for normal enum class: list of (member, alias) pairs | |||||
for bit combined class: list of (tuple of members, alias) paris | |||||
""" | """ | ||||
__slots__ = ['name', 'name_field', 'members', 'default', | __slots__ = ['name', 'name_field', 'members', 'default', | ||||
'member_alias', 'combined'] | 'member_alias', 'combined'] | ||||
@@ -90,17 +100,11 @@ class member_defs: | |||||
name = member_defs.Doc.make(name) | name = member_defs.Doc.make(name) | ||||
assert name.id[0].isupper() | assert name.id[0].isupper() | ||||
members = tuple(map(member_defs.Doc.make, members)) | members = tuple(map(member_defs.Doc.make, members)) | ||||
if isinstance(default, str): | |||||
if default not in name_field: | |||||
raise ValueError( | |||||
"Default value '{}' does not exist.".format(default)) | |||||
default = name_field.index(default) | |||||
assert isinstance(default, int) | |||||
self.name = name | self.name = name | ||||
self.combined = combined | self.combined = combined | ||||
self.name_field = self.get_name_field(name.id, name_field) | self.name_field = self.get_name_field(name.id, name_field) | ||||
self.members = members | self.members = members | ||||
self.default = default | |||||
self.default = self.normalize_enum_value(default) | |||||
self.all_enums[(param_name, name.id)] = self | self.all_enums[(param_name, name.id)] = self | ||||
@@ -114,6 +118,43 @@ class member_defs: | |||||
assert isinstance(name_field, str) | assert isinstance(name_field, str) | ||||
return name_field | return name_field | ||||
def normalize_enum_value(self, value): | |||||
def normalize(v): | |||||
if isinstance(v, str): | |||||
if v not in self.members: | |||||
raise ValueError( | |||||
"enum member '{}' does not exist.".format(v)) | |||||
v = self.members.index(v) | |||||
assert isinstance(v, int) | |||||
return v | |||||
if self.combined: | |||||
if isinstance(value, int): | |||||
value = self.decompose_combined_enum(value) | |||||
assert isinstance(value, tuple) | |||||
value = tuple(normalize(i) for i in value) | |||||
return value | |||||
else: | |||||
return normalize(value) | |||||
@staticmethod | |||||
def decompose_combined_enum(v): | |||||
"""Integer => tuple of the indexes of the enum members""" | |||||
assert isinstance(v, int) | |||||
idx = 0 | |||||
members = [] | |||||
while v > 0: | |||||
if v & 1: | |||||
members.append(idx) | |||||
idx += 1 | |||||
v >>= 1 | |||||
return tuple(members) | |||||
def compose_combined_enum(self, v): | |||||
"""tuple of members => Integer""" | |||||
assert self.combined and isinstance(v, tuple) | |||||
norm_v = self.normalize_enum_value(v) | |||||
return sum(1 << i for i in norm_v) | |||||
class Field(Base): | class Field(Base): | ||||
"""define a normal data field""" | """define a normal data field""" | ||||
__slots__ = ['name', 'dtype', 'default'] | __slots__ = ['name', 'dtype', 'default'] | ||||
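
The two helpers above are a plain round trip between a bitmask and the tuple of set-bit indices; a minimal standalone sketch of the same arithmetic (not tied to member_defs):

def decompose(v):
    # bitmask -> tuple of set-bit indices, mirroring decompose_combined_enum
    idx, out = 0, []
    while v > 0:
        if v & 1:
            out.append(idx)
        idx += 1
        v >>= 1
    return tuple(out)

def compose(indices):
    # tuple of indices -> bitmask, mirroring compose_combined_enum
    return sum(1 << i for i in indices)

assert decompose(0b101) == (0, 2)
assert compose((0, 2)) == 0b101
assert decompose(compose((1, 3))) == (1, 3)
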
@@ -146,6 +187,10 @@ class member_defs: | |||||
src_name = name | src_name = name | ||||
self.src_name = src_name | self.src_name = src_name | ||||
self.default = default | self.default = default | ||||
# TODO: remove this assertion if needed; adding mock param_defs in | |||||
# current testing framework is too complicated, and currently we | |||||
# only allow aliasing of normal enum | |||||
assert not self.src_enum.combined | |||||
@property | @property | ||||
def src_enum(self): | def src_enum(self): | ||||
@@ -157,7 +202,7 @@ class member_defs: | |||||
set""" | set""" | ||||
if self.default is None: | if self.default is None: | ||||
return self.src_enum.default | return self.src_enum.default | ||||
return self.default | |||||
return self.src_enum.normalize_enum_value(self.default) | |||||
class ParamDef: | class ParamDef: | ||||
@@ -198,7 +243,7 @@ class ParamDef: | |||||
self.name.id, name, name_field, members, default, member_alias)) | self.name.id, name, name_field, members, default, member_alias)) | ||||
return self | return self | ||||
def add_bit_combination_enum(self, name, *members, default=0, | |||||
def add_bit_combination_enum(self, name, *members, default=tuple(), | |||||
name_field=None, member_alias=[]): | name_field=None, member_alias=[]): | ||||
self.members.append(member_defs.Enum( | self.members.append(member_defs.Enum( | ||||
self.name.id, name, name_field, members, default, member_alias, True)) | self.name.id, name, name_field, members, default, member_alias, True)) | ||||
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): | |||||
' for idx, v in enumerate(pdata):\n' | ' for idx, v in enumerate(pdata):\n' | ||||
' if isinstance(v, _EnumBase):\n' | ' if isinstance(v, _EnumBase):\n' | ||||
' pdata[idx] = _enum_member2num[id(v)]\n' | ' pdata[idx] = _enum_member2num[id(v)]\n' | ||||
' elif isinstance(v, _BitCombinedEnumBase):\n' | |||||
' pdata[idx] = v._value_\n' | |||||
' return tag + self._packer.pack(*pdata)\n' | ' return tag + self._packer.pack(*pdata)\n' | ||||
'\n' | '\n' | ||||
) | ) | ||||
self._write( | |||||
'class _EnumBase(enum.Enum):\n' | |||||
# it's hard to mix custom implemention into enum, just do copy-paste instead | |||||
classbody = ( | |||||
' @classmethod\n' | ' @classmethod\n' | ||||
' def __normalize(cls, val):\n' | ' def __normalize(cls, val):\n' | ||||
' if isinstance(val, str):\n' | ' if isinstance(val, str):\n' | ||||
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): | |||||
' return super()._missing_(value)\n' | ' return super()._missing_(value)\n' | ||||
'\n' | '\n' | ||||
) | ) | ||||
self._write( | |||||
'class _EnumBase(enum.Enum):\n' + classbody | |||||
) | |||||
self._write( | |||||
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody | |||||
) | |||||
if not self._imperative: | if not self._imperative: | ||||
self._write( | self._write( | ||||
'def _as_dtype_num(dtype):\n' | 'def _as_dtype_num(dtype):\n' | ||||
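
The generated `_BitCombinedEnumBase` derives from the standard library's `enum.Flag`, so `|` composes members and `_value_` already holds the packed bitmask that `pdata[idx] = v._value_` serializes above. A minimal stdlib-only illustration (hypothetical member names):

import enum

class Strategy(enum.Flag):
    HEURISTIC = 1 << 0
    PROFILE = 1 << 1
    REPRODUCIBLE = 1 << 2

s = Strategy.PROFILE | Strategy.REPRODUCIBLE
assert s._value_ == (1 << 1) | (1 << 2)  # the packed bitmask that gets written
assert Strategy.PROFILE in s             # flag membership test
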
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
     def _on_member_enum(self, e):
         qualname = '{}.{}'.format(self._cur_param_name, e.name)
-        self._write('class %s(_EnumBase):', e.name, indent=1)
+        if e.combined:
+            self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1)
+        else:
+            self._write('class %s(_EnumBase):', e.name, indent=1)
         self._write_doc(e.name)
         for idx, emem in enumerate(e.members):
-            self._write('%s = "%s"', emem, emem)
-            self._write_doc(emem)
             if e.combined:
-                self._enum_member2num.append('id({}.{}):{}'.format(
-                    qualname, emem, 1<<idx))
+                self._write('%s = 1 << %d', emem, idx)
+                self._write_doc(emem)
             else:
+                self._write('%s = "%s"', emem, emem)
+                self._write_doc(emem)
                 self._enum_member2num.append('id({}.{}):{}'.format(
                     qualname, emem, idx))
-        for emem, emem_alis in e.member_alias:
-            self._write('%s = %s', emem_alis, emem)
+        for emem, emem_alias in e.member_alias:
+            if e.combined:
+                self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
+            else:
+                self._write('%s = %s', emem_alias, emem)
         self._unindent()
         self._write('')
+        if e.combined:
+            default = e.compose_combined_enum(e.default)
+        else:
+            default = "'{}'".format(e.members[e.default])
         self._cur_fields.append(self.FieldDef(
             name=e.name_field,
             cvt='{}.convert({})'.format(qualname, e.name_field),
             fmt='I',
-            default="'{}'".format(e.members[e.default]),
+            default=default,
             type=qualname,
             doc=None))
@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase): | |||||
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | ||||
s = e.src_enum | s = e.src_enum | ||||
qualname = '{}.{}'.format(e.src_class, e.src_name) | qualname = '{}.{}'.format(e.src_class, e.src_name) | ||||
if s.combined: | |||||
default = s.compose_combined_enum(e.get_default()) | |||||
else: | |||||
default = "'{}'".format(s.members[e.get_default()]) | |||||
self._cur_fields.append(self.FieldDef( | self._cur_fields.append(self.FieldDef( | ||||
name=e.name_field, | name=e.name_field, | ||||
cvt='{}.convert({})'.format(qualname, e.name_field), | cvt='{}.convert({})'.format(qualname, e.name_field), | ||||
fmt='I', | fmt='I', | ||||
default="'{}'".format(s.members[e.get_default()]), | |||||
default=default, | |||||
type=qualname, | type=qualname, | ||||
doc=None)) | doc=None)) | ||||
@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
                 v += ','
             self._write(v)
         for mem, alias in e.member_alias:
-            self._write('%s = %s,', alias, mem)
+            if e.combined:
+                self._write('%s = %s,', alias, e.compose_combined_enum(mem))
+            else:
+                self._write('%s = %s,', alias, mem)
         self._write('};', indent=-1)
         self._non_static_members.append(e)
         self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
                     str(e.name).upper(), len(e.members))
-        self._add_ctor_args(e.name,
-                            '{}::{}'.format(e.name, e.members[e.default]),
-                            e.name_field)
+        if e.combined:
+            default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
+        else:
+            default = '{}::{}'.format(e.name, e.members[e.default])
+        self._add_ctor_args(e.name, default, e.name_field)

     def _on_member_enum_alias(self, e):
         s = e.src_enum
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase): | |||||
self._non_static_members.append(e) | self._non_static_members.append(e) | ||||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | ||||
str(e.name).upper(), len(s.members)) | str(e.name).upper(), len(s.members)) | ||||
self._add_ctor_args(e.name, | |||||
'{}::{}'.format(e.name, | |||||
s.members[e.get_default()]), | |||||
e.name_field) | |||||
if s.combined: | |||||
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) | |||||
else: | |||||
default = '{}::{}'.format(e.name, s.members[e.get_default()]) | |||||
self._add_ctor_args(e.name, default, e.name_field) | |||||
def _on_member_field(self, f): | def _on_member_field(self, f): | ||||
self._non_static_members.append(f) | self._non_static_members.append(f) | ||||
@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase): | |||||
return | return | ||||
# wrapped with default value | # wrapped with default value | ||||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) | |||||
if e.combined: | |||||
default_val = "static_cast<{}::{}>({})".format( | |||||
fullname, e.name, e.compose_combined_enum(e.default)) | |||||
else: | |||||
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default]) | |||||
wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | ||||
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase): | |||||
self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
# wrapped with default value | # wrapped with default value | ||||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) | |||||
s = e.src_enum | |||||
if s.combined: | |||||
default_val = "static_cast<{}::{}>({})".format( | |||||
fullname, e.name, s.compose_combined_enum(e.get_default())) | |||||
else: | |||||
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()]) | |||||
wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | ||||
@@ -185,7 +185,7 @@ SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order( | |||||
} | } | ||||
//! conv1x1 | //! conv1x1 | ||||
im2col_prefer |= (FH == 1 && FW == 1); | im2col_prefer |= (FH == 1 && FW == 1); | ||||
//! x86 8x8x16 not optmized, so it will use fallback im2col+matmul | |||||
//! x86 8x8x16 not optimized, so it will use fallback im2col+matmul | |||||
if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) { | if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) { | ||||
im2col_prefer = true; | im2col_prefer = true; | ||||
} | } | ||||
@@ -8,6 +8,9 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

+from .._imperative_rt import make_const
+from .._imperative_rt.core2 import SymbolVar, Tensor
+

 class Const:
     def __init__(self, value=None, *, dtype=None, device=None):
@@ -19,7 +22,19 @@ class Const:
         from ...tensor import Tensor

         device = self.device
-        if device is None:
-            device = reference[0].device
+
+        if len(reference) != 0:
+            reference = reference[0]
+            assert isinstance(
+                reference, (SymbolVar, Tensor)
+            ), "Reference should be Tensor or VarNode"
+
+            if device is None:
+                device = reference.device
+
+            if isinstance(reference, SymbolVar):
+                cls = type(reference)
+                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
+                return (rst,)
+
         return (Tensor(self.value, self.dtype, self.device, True),)
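
With this change the functor dispatches on its reference argument: a plain Tensor reference (or no reference) yields an imperative Tensor, while a SymbolVar reference produces a graph constant via make_const on the same graph. A usage sketch, assuming megengine is importable and this module path (megengine.core.ops.special) is correct:

import megengine
from megengine.core.ops.special import Const

x = megengine.Tensor([1.0, 2.0])
# the reference x decides whether the constant is imperative or symbolic
(c,) = Const(0.5, dtype=x.dtype, device=x.device)(x)
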
@@ -13,7 +13,7 @@ from typing import Union
 import numpy as np

 from .._imperative_rt.common import CompNode
-from .._imperative_rt.core2 import Tensor, apply
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from ..ops import builtin
 from ..ops.builtin import Elemwise, GetVarShape
 from . import utils
@@ -230,7 +230,9 @@ def _todo(*_):

 def _expand_args(args):
     if len(args) == 1:
-        if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
+        if isinstance(
+            args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
+        ):
             args = args[0]
     return args
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import re
 from collections import namedtuple
 from typing import Union
@@ -22,6 +23,12 @@ from .._imperative_rt.common import (
 )


+def get_dtype_bit(dtype_name: str):
+    numbers = re.findall(r"\d+", dtype_name)
+    assert len(numbers) == 1, "Unsupported dtype name with more than one number."
+    return int(numbers[0])
+
+
 # normal dtype related
 def is_lowbit(dtype):
     return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)
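
The helper simply pulls the single integer out of a dtype name; the values below follow directly from the regex:

assert get_dtype_bit("int8") == 8
assert get_dtype_bit("float32") == 32
# a name with two numbers (e.g. a hypothetical "int4x2") would trip the assertion
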
@@ -10,7 +10,7 @@ from typing import Iterable

 import numpy as np

-from .._imperative_rt.core2 import Tensor, apply
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.special import Const
@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
         return True

     def get_index(i):
-        if not isinstance(i, (Tensor)):
+        if not isinstance(i, (Tensor, SymbolVar)):
             if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
-                (i,) = Const(i, dtype=np.bool_, device=inp.device)()
+                (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
             else:
-                (i,) = Const(i, dtype=np.int32, device=inp.device)()
+                (i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
             return i
-        assert isinstance(i, Tensor)
+        assert isinstance(i, (Tensor, SymbolVar))
         if i.dtype != np.bool_:
             return i
         _, ind = apply(builtin.CondTake(), i, i)
@@ -197,9 +197,9 @@ def try_condtake(tensor, index):
     ):
         return []
     if isinstance(index, np.ndarray):
-        (index,) = Const(index, dtype=np.bool_, device=tensor.device)()
-    assert isinstance(index, Tensor)
-    if not isinstance(tensor, Tensor):
+        (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
+    assert isinstance(index, (Tensor, SymbolVar))
+    if not isinstance(tensor, (Tensor, SymbolVar)):
         raise TypeError("input must be a tensor")
     if tensor.device != index.device:
         raise ValueError(
@@ -214,11 +214,16 @@ def getitem(tensor, index):
         return try_result[0]
     tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
     for v in tensors:
+        if v.shape is None:
+            break
         if isinstance(v.shape, v.__class__):
             break
         if len(v.shape) > 0 and v.shape[0] == 0:
-            (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)()
+            (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
+                tensor
+            )
             return empty_tensor
     if use_subtensor:
         op = builtin.Subtensor(items=items)
     else:
@@ -235,8 +240,8 @@ def setitem(tensor, index, value):
     if len(try_result) == 2:
         index = try_result[1]
         tensor = tensor.reshape(-1)
-    if not isinstance(value, Tensor):
-        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
+    if not isinstance(value, (Tensor, SymbolVar)):
+        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
     tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
     if use_subtensor:
         op = builtin.Subtensor(items=items)
@@ -11,8 +11,9 @@ from typing import Iterable, Union

 import numpy as np

-from .._imperative_rt import VarNode
-from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
+from .._imperative_rt import make_const
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
+from .._wrap import device as as_device
 from ..ops import builtin
 from ..ops.special import Const
 from .dtype import is_dtype_equal, is_quantize
@@ -38,13 +39,9 @@ def set_convert_inputs(flag):

 def concatenate(inputs, axis=0, *, device=None):
-    dtype = dtype_promotion(inputs)
-    device = get_device(inputs)
-
-    def convert(x):
-        return convert_single_value(x, dtype=dtype, device=device)
-
-    inputs = tuple(map(convert, inputs))
+    inputs = convert_inputs(*inputs)
+    if device is None:
+        device = get_device(inputs)
     (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
     return result
@@ -60,7 +57,7 @@ def astype(x, dtype):

 def convert_single_value(v, *, dtype=None, device=None):
-    if isinstance(v, (Tensor, VarNode)):
+    if isinstance(v, (Tensor, SymbolVar)):
         if not is_quantize(v.dtype):
             v = astype(v, dtype)
     else:
@@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None):
     return v


-def convert_inputs(*args: Tensor):
+def convert_inputs(*args, device=None):
     if not _enable_convert_inputs:
         return args

     dtype = dtype_promotion(args)
-    device = get_device(args)
+    if device is None:
+        device = get_device(args)
+    device = as_device(device)
+
+    graph = None
+    sym_type = None
+    for a in args:
+        if isinstance(a, SymbolVar):
+            if graph is None:
+                graph = a.var.graph
+                sym_type = type(a)
+            else:
+                assert graph == a.var.graph
+
+    args = list(args)
+    if graph is not None:
+        for i in range(len(args)):
+            if not isinstance(args[i], SymbolVar):
+                rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
+                args[i] = sym_type(rst)

     def convert(value):
         if value is None:
             return value
-        return convert_single_value(value, dtype=dtype, device=device)
+        return convert_single_value(value, dtype=dtype, device=device.to_c())

     return tuple(map(convert, args))
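
The loop above enforces that all symbolic args share one graph, then lifts every non-symbolic arg into that graph as a constant. A framework-free sketch of the same pattern (Sym and make_const here are stand-ins, not megengine API):

class Sym:
    def __init__(self, graph, payload):
        self.graph, self.payload = graph, payload

def lift_constants(args, make_const):
    graph = None
    for a in args:
        if isinstance(a, Sym):
            assert graph is None or graph is a.graph, "args span multiple graphs"
            graph = a.graph
    if graph is None:
        return list(args)  # purely imperative call, nothing to lift
    return [a if isinstance(a, Sym) else Sym(graph, make_const(graph, a)) for a in args]

g = object()
out = lift_constants([Sym(g, "x"), 2.0], lambda graph, v: v)
assert all(isinstance(a, Sym) for a in out)
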
@@ -98,14 +113,14 @@ def result_type(*args):

 def isscalar(x):
-    if isinstance(x, Tensor):
+    if isinstance(x, (Tensor, SymbolVar)):
         return x._isscalar()

     return np.isscalar(x)


 def setscalar(x):
-    if isinstance(x, Tensor):
+    if isinstance(x, (Tensor, SymbolVar)):
         x._setscalar()
     else:
         raise NotImplementedError("Unsupport type {}".format(type(x)))
@@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
     if not isinstance(x, collections.abc.Sequence):
         raise TypeError

-    if any(isinstance(i, Tensor) for i in x):
+    if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
         x = concatenate(x, device=device)
         if dtype is not None:
             x = astype(x, dtype)
@@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None):

 def _expand_int(s, i):
-    if isinstance(i, Tensor):
+    if isinstance(i, (Tensor, SymbolVar)):
         i_np = i.numpy()
         if i_np.ndim == 0:
             s.append(int(i_np))
@@ -40,7 +40,7 @@ def set_execution_strategy(option):
     * HEURISTIC uses heuristic to choose the fastest algorithm.
     * PROFILE runs possible algorithms on real device to find the best one.
     * REPRODUCIBLE uses the algorithms that is reproducible.
-    * OPTMIZED uses the algorithms that is optimized.
+    * OPTIMIZED uses the algorithms that is optimized.

     The default strategy is HEURISTIC, this options can be combined to
     form a combination option, e.g. PROFILE | REPRODUCIBLE
@@ -9,8 +9,7 @@
 # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
 import numpy as np

-from ..core._imperative_rt.core2 import apply
-from ..core._imperative_rt.graph import VarNode
+from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor import utils
@@ -72,7 +71,7 @@ __all__ = [

 def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
+    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
     if len(tensor_args) == 0:
         dtype = utils.dtype_promotion(args)
         first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union
 import numpy as np

 from ..core._imperative_rt import CompNode
-from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core._wrap import device as as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
@@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten
     return result


-def full(shape, value, dtype="float32", device=None):
+def full(shape, value, dtype="float32", device=None) -> Tensor:
     """
     Returns a tensor with given shape and value.
     """
@@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None):
     return broadcast_to(x, shape)


-def ones(shape, dtype="float32", device=None):
+def ones(shape, dtype="float32", device=None) -> Tensor:
     """
     Returns a ones tensor with given shape.

@@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None):
     return full(shape, 1.0, dtype=dtype, device=device)


-def zeros(shape, dtype="float32", device=None):
+def zeros(shape, dtype="float32", device=None) -> Tensor:
     """
     Returns a zero tensor with given shape.
     """
     return full(shape, 0.0, dtype=dtype, device=device)


-def zeros_like(inp: Tensor) -> Tensor:
+def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
     """
     Returns a zero tensor with the same shape as input tensor.

@@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor:
         [0 0 0]]

     """
-    return zeros(inp.shape, dtype=inp.dtype, device=inp.device)
+    return full_like(inp, 0.0)


-def ones_like(inp: Tensor) -> Tensor:
+def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
     """
     Returns a ones tensor with the same shape as input tensor.
     """
-    return ones(inp.shape, dtype=inp.dtype, device=inp.device)
+    return full_like(inp, 1.0)


-def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
+def full_like(
+    inp: Union[Tensor, SymbolVar], value: Union[int, float]
+) -> Union[Tensor, SymbolVar]:
     """
     Returns a tensor filled with given value with the same shape as input tensor.
     """
-    return full(inp.shape, value, dtype=inp.dtype, device=inp.device)
+    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+    if inp.shape == ():
+        return x
+    return broadcast_to(x, inp.shape)


 def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
@@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
     if len(inps) == 1:
         return inps[0]

-    dtype = dtype_promotion(inps)
+    inps = convert_inputs(*inps, device=device)
     if device is None:
         device = get_device(inps)
     device = as_device(device)
-
-    def convert(x):
-        return convert_single_value(x, dtype=dtype, device=device)
-
-    inps = tuple(map(convert, inps))
     (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
     return result
@@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0):
                 Ntotal, axis, Nsections
             )
         )
+    func = (
+        floor_div
+        if isinstance(Nsections, (SymbolVar, Tensor))
+        else lambda x, y: x // y
+    )
     div_points = [0] + [
-        floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
+        func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
     ]
     for i in range(2, Nsections + 1):
         div_points[i] = div_points[i - 1] + div_points[i]
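
The `func` switch keeps plain Python ints on the fast `//` path while symbolic sections go through floor_div. The split-point arithmetic itself, checked with plain ints (Ntotal=10, Nsections=3):

Ntotal, Nsections = 10, 3
div_points = [0] + [
    (Ntotal + Nsections - i - 1) // Nsections for i in range(Nsections)
]                       # [0, 4, 3, 3]: ceil-like chunk sizes, larger chunks first
for i in range(2, Nsections + 1):
    div_points[i] = div_points[i - 1] + div_points[i]
assert div_points == [0, 4, 7, 10]  # boundaries of the three chunks
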
@@ -925,11 +931,15 @@ def linspace(
     if not (cur_device is None or device == cur_device):
         raise ("ambiguous device for linspace opr")

-    if not isinstance(start, Tensor):
+    is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
+    if any(is_symbolvar) and not all(is_symbolvar):
+        raise TypeError("start, stop and num should all be VarNode or none of them")
+
+    if not isinstance(start, (Tensor, SymbolVar)):
         start = Tensor(start, device=device)
-    if not isinstance(stop, Tensor):
+    if not isinstance(stop, (Tensor, SymbolVar)):
         stop = Tensor(stop, device=device)
-    if not isinstance(num, Tensor):
+    if not isinstance(num, (Tensor, SymbolVar)):
         num = Tensor(num, device=device)

     op = builtin.Linspace(comp_node=device)
@@ -983,7 +993,7 @@ def arange(
         stop = stop.astype("float32")
     if isinstance(step, Tensor):
         step = step.astype("float32")
-    num = ceil(Tensor((stop - start) / step, device=device))
+    num = ceil((stop - start) / step)
     stop = start + step * (num - 1)
     result = linspace(start, stop, num, device=device)
     if np.dtype(dtype) == np.int32:
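
Dropping the Tensor(...) wrapper lets a SymbolVar flow straight through ceil; numerically nothing changes, e.g. for start=0, stop=5, step=2:

import math

start, stop, step = 0.0, 5.0, 2.0
num = math.ceil((stop - start) / step)  # 3 points
stop = start + step * (num - 1)         # adjusted endpoint: 4.0
# linspace(0.0, 4.0, 3) -> [0.0, 2.0, 4.0], matching np.arange(0, 5, 2)
assert num == 3 and stop == 4.0
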
@@ -607,10 +607,10 @@ class Module(metaclass=ABCMeta):

     def __getattribute__(self, name: str):
         value = super().__getattribute__(name)
-        if name == "_name":
+        if name == "__dict__":
             return value
-        if isinstance(value, (Tensor, Module)):
-            value._name = name
+        for prefix, variable in _expand_structure(name, value):
+            variable._name = prefix
         return value

     def __setattr__(self, name: str, value):
@@ -23,7 +23,7 @@ class Concat(QuantizedModule):
         self.output_dtype = dtype

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
-        new_inps = (x.astype(self.output_dtype) for x in inps)
+        new_inps = tuple(x.astype(self.output_dtype) for x in inps)
         return F.concat(new_inps, axis)

     @classmethod
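
Materializing the generator matters because concat (see `if len(inps) == 1` above) takes len() of its input and may iterate it more than once, which a generator cannot support; a minimal illustration:

gen = (x for x in [1, 2, 3])
tup = tuple(x for x in [1, 2, 3])
# len(gen) raises TypeError: object of type 'generator' has no len()
assert len(tup) == 3
assert list(tup) == list(tup)  # tuples survive repeated iteration; generators do not
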
@@ -92,6 +92,7 @@ class Sequential(Module):
         return [getattr(self, key) for key in self.layer_keys]

     def forward(self, inp):
-        for layer in self.layer_values:
+        # avoid layer_values as a name prefix, see Module.__getattribute__
+        for layer in [getattr(self, key) for key in self.layer_keys]:
             inp = layer(inp)
         return inp
@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
 from typing import Dict
@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta):
                 cur_id += 1

         for param, st in self._state.items():
+            _st = copy.copy(st)
             if not keep_var:
                 for k, v in st.items():
-                    st[k] = v.numpy()
-            state[param2id[param]] = st
+                    _st[k] = v.numpy()
+            state[param2id[param]] = _st

         for group in self.param_groups:
             param_group = {k: v for k, v in group.items() if k != "params"}
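
The copy.copy guard keeps state serialization from clobbering the optimizer's live state: rebinding keys in the shallow copy leaves the original dict's Tensor entries intact. The stdlib behavior it relies on ("exp_avg" is just a stand-in key):

import copy

st = {"exp_avg": "<live Tensor>"}   # stands in for one optimizer state entry
_st = copy.copy(st)                 # new dict object, same values
_st["exp_avg"] = "<numpy array>"    # rebinding in the copy ...
assert st["exp_avg"] == "<live Tensor>"  # ... leaves the live state untouched
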
@@ -6,8 +6,16 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .fake_quant import FakeQuantize
-from .observer import Observer
+from .fake_quant import TQT, FakeQuantize
+from .observer import (
+    ExponentialMovingAverageObserver,
+    HistogramObserver,
+    MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
+)
 from .qconfig import (
     QConfig,
     calibration_qconfig,
@@ -8,14 +8,19 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import argparse
 import logging
+import re

 import numpy as np

 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
-    print_flops_stats,
-    print_params_stats,
+    enable_receptive_field,
+    get_op_stats,
+    get_param_stats,
+    print_op_stats,
+    print_param_stats,
+    print_summary,
     sizeof_fmt,
 )
 from megengine.utils.network import Network
@@ -40,34 +45,41 @@ def visualize(
     :param log_params: whether print and record params size.
     :param log_flops: whether print and record op flops.
     """
-    try:
-        from tensorboard.compat.proto.attr_value_pb2 import AttrValue
-        from tensorboard.compat.proto.config_pb2 import RunMetadata
-        from tensorboard.compat.proto.graph_pb2 import GraphDef
-        from tensorboard.compat.proto.node_def_pb2 import NodeDef
-        from tensorboard.compat.proto.step_stats_pb2 import (
-            AllocatorMemoryUsed,
-            DeviceStepStats,
-            NodeExecStats,
-            StepStats,
-        )
-        from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
-        from tensorboard.compat.proto.versions_pb2 import VersionDef
-        from tensorboardX import SummaryWriter
-    except ImportError:
-        logger.error(
-            "TensorBoard and TensorboardX are required for visualize.", exc_info=True
-        )
-        return
+    if log_path:
+        try:
+            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
+            from tensorboard.compat.proto.config_pb2 import RunMetadata
+            from tensorboard.compat.proto.graph_pb2 import GraphDef
+            from tensorboard.compat.proto.node_def_pb2 import NodeDef
+            from tensorboard.compat.proto.step_stats_pb2 import (
+                AllocatorMemoryUsed,
+                DeviceStepStats,
+                NodeExecStats,
+                StepStats,
+            )
+            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+            from tensorboard.compat.proto.versions_pb2 import VersionDef
+            from tensorboardX import SummaryWriter
+        except ImportError:
+            logger.error(
+                "TensorBoard and TensorboardX are required for visualize.",
+                exc_info=True,
+            )
+            return

     # FIXME: remove this after resolving "span dist too large" warning
     old_level = set_mgb_log_level(logging.ERROR)

+    enable_receptive_field()
+
     graph = Network.load(model_path)
-    writer = SummaryWriter(log_path)

     def process_name(name):
-        return name.replace(".", "/").encode(encoding="utf-8")
+        # node names that start with a dot or contain a float constant lead to a display bug
+        if not re.match(r"^[+-]?\d*\.\d*", name):
+            name = name.replace(".", "/")
+        return name.encode(encoding="utf-8")

+    summary = [["item", "value"]]
     node_list = []
     flops_list = []
     params_list = []
@@ -84,78 +96,90 @@ def visualize(
         node_oup = node.outputs[0]
         inp_list = [process_name(var.owner.name) for var in node.inputs]

-        attr = {
-            "_output_shapes": AttrValue(
-                list=AttrValue.ListValue(
-                    shape=[
-                        TensorShapeProto(
-                            dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape]
-                        )
-                    ]
-                )
-            ),
-        }
-        if hasattr(node, "calc_flops"):
-            flops_num = node.calc_flops()
-            # add op flops attr
-            attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8"))
-            flops_list.append(
-                dict(
-                    name=node.name,
-                    class_name=node.type,
-                    input_shapes=[i.shape for i in node.inputs],
-                    output_shapes=[o.shape for o in node.outputs],
-                    flops_num=flops_num,
-                    flops_cum=0,
-                )
-            )
+        if log_path:
+            # detail format see tensorboard/compat/proto/attr_value.proto
+            attr = {
+                "_output_shapes": AttrValue(
+                    list=AttrValue.ListValue(
+                        shape=[
+                            TensorShapeProto(
+                                dim=[
+                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
+                                ]
+                            )
+                        ]
+                    )
+                ),
+                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
+                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
+            }
+
+        flops_stats = get_op_stats(node, node.inputs, node.outputs)
+        if flops_stats is not None:
+            # add op flops attr
+            if log_path and "flops_num" in flops_stats:
+                attr["flops"] = AttrValue(
+                    s=sizeof_fmt(flops_stats["flops_num"]).encode(encoding="utf-8")
+                )
+            flops_stats["name"] = node.name
+            flops_stats["class_name"] = node.type
+            flops_list.append(flops_stats)

         if node.type == "ImmutableTensor":
-            param_dim = np.prod(node_oup.shape)
-            # TODO: consider other quantize dtypes
-            param_bytes = 1 if is_quantize(node_oup.dtype) else 4
-            # add tensor size attr
-            attr["size"] = AttrValue(
-                s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
-            )
-            params_list.append(
-                dict(
-                    name=node.name,
-                    shape=node_oup.shape,
-                    param_dim=param_dim,
-                    bits=param_bytes * 8,
-                    size=param_dim * param_bytes,
-                    size_cum=0,
-                    mean="{:.2g}".format(node.numpy().mean()),
-                    std="{:.2g}".format(node.numpy().std()),
-                )
-            )
+            param_stats = get_param_stats(node.numpy())
+            # add tensor size attr
+            if log_path:
+                attr["size"] = AttrValue(
+                    s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
+                )
+            param_stats["name"] = node.name
+            params_list.append(param_stats)

-        node_list.append(
-            NodeDef(
-                name=process_name(node.name),
-                op=node.type,
-                input=inp_list,
-                attr=attr,
-            )
-        )
+        if log_path:
+            # FIXME(MGE-2165): nodes outside the network module may lead to an unknown display bug
+            if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
+                continue
+            node_list.append(
+                NodeDef(
+                    name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
+                )
+            )

-    total_flops, total_params = 0, 0
+    # summary
+    extra_info = {
+        "#ops": len(graph.all_oprs),
+        "#params": len(params_list),
+    }
+
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
-        total_params = print_params_stats(params_list, bar_length_max)
+        total_param_dims, total_param_size = print_param_stats(
+            params_list, bar_length_max
+        )
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
-        total_flops = print_flops_stats(flops_list, bar_length_max)
+        total_flops = print_op_stats(flops_list, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )

-    graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
-    device = "/device:CPU:0"
-    stepstats = RunMetadata(
-        step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
-    )
-    writer._get_file_writer().add_graph((graph_def, stepstats))
+    if log_path:
+        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
+        device = "/device:CPU:0"
+        stepstats = RunMetadata(
+            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
+        )
+        writer = SummaryWriter(log_path)
+        writer._get_file_writer().add_graph((graph_def, stepstats))
+
+    print_summary(**extra_info)

     # FIXME: remove this after resolving "span dist too large" warning
     _imperative_rt_logger.set_log_level(old_level)
-    return total_params, total_flops
+    return total_param_size, total_flops
 def main():
@@ -164,7 +188,7 @@ def main():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument("model_path", help="dumped model path.")
-    parser.add_argument("log_path", help="tensorboard log path.")
+    parser.add_argument("--log_path", help="tensorboard log path.")
     parser.add_argument(
         "--bar_length_max",
         type=int,
@@ -179,7 +203,20 @@ def main():
     parser.add_argument(
         "--log_flops", action="store_true", help="whether print and record op flops.",
     )
-    visualize(**vars(parser.parse_args()))
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        help="whether to print all stats. Tensorboard logs will be placed in './log' if log_path is not specified.",
+    )
+    args = parser.parse_args()
+    if args.all:
+        args.log_params = True
+        args.log_flops = True
+        if not args.log_path:
+            args.log_path = "./log"
+    kwargs = vars(args)
+    kwargs.pop("all")
+    visualize(**kwargs)


 if __name__ == "__main__":
@@ -5,16 +5,17 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
 from functools import partial

 import numpy as np
 import tabulate

 import megengine as mge
-import megengine.core.tensor.dtype as dtype
 import megengine.module as m
 import megengine.module.qat as qatm
 import megengine.module.quantized as qm
+from megengine.core.tensor.dtype import get_dtype_bit
 from megengine.functional.tensor import zeros

 try:
@@ -26,61 +27,99 @@ logger = mge.get_logger(__name__) | |||||
logger.setLevel("INFO") | logger.setLevel("INFO") | ||||
CALC_FLOPS = {} | |||||
_calc_flops_dict = {} | |||||
_calc_receptive_field_dict = {} | |||||
def _register_modules(*modules): | |||||
def _receptive_field_fallback(module, inputs, outputs): | |||||
if not _receptive_field_enabled: | |||||
return | |||||
assert not hasattr(module, "_rf") | |||||
assert not hasattr(module, "_stride") | |||||
if len(inputs) == 0: | |||||
# TODO: support other dimension | |||||
module._rf = (1, 1) | |||||
module._stride = (1, 1) | |||||
return module._rf, module._stride | |||||
rf, stride = preprocess_receptive_field(module, inputs, outputs) | |||||
module._rf = rf | |||||
module._stride = stride | |||||
return rf, stride | |||||
# key tuple, impl_dict, fallback | |||||
_iter_list = [ | |||||
("flops_num", _calc_flops_dict, None), | |||||
( | |||||
("receptive_field", "stride"), | |||||
_calc_receptive_field_dict, | |||||
_receptive_field_fallback, | |||||
), | |||||
] | |||||
_receptive_field_enabled = False | |||||
def _register_dict(*modules, dict=None): | |||||
def callback(impl): | def callback(impl): | ||||
for module in modules: | for module in modules: | ||||
CALC_FLOPS[module] = impl | |||||
dict[module] = impl | |||||
return impl | return impl | ||||
return callback | return callback | ||||
@_register_modules( | |||||
m.Conv2d, | |||||
m.ConvTranspose2d, | |||||
m.LocalConv2d, | |||||
qm.Conv2d, | |||||
qm.ConvRelu2d, | |||||
qm.ConvBn2d, | |||||
qm.ConvBnRelu2d, | |||||
qatm.Conv2d, | |||||
qatm.ConvRelu2d, | |||||
qatm.ConvBn2d, | |||||
qatm.ConvBnRelu2d, | |||||
def register_flops(*modules): | |||||
return _register_dict(*modules, dict=_calc_flops_dict) | |||||
def register_receptive_field(*modules): | |||||
return _register_dict(*modules, dict=_calc_receptive_field_dict) | |||||
def enable_receptive_field(): | |||||
global _receptive_field_enabled | |||||
_receptive_field_enabled = True | |||||
def disable_receptive_field(): | |||||
global _receptive_field_enabled | |||||
_receptive_field_enabled = False | |||||
@register_flops( | |||||
m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d | |||||
) | ) | ||||
def count_convNd(module, input, output): | |||||
def flops_convNd(module: m.Conv2d, inputs, outputs): | |||||
bias = 1 if module.bias is not None else 0 | bias = 1 if module.bias is not None else 0 | ||||
group = module.groups | |||||
ic = input[0].shape[1] | |||||
oc = output[0].shape[1] | |||||
goc = oc // group | |||||
gic = ic // group | |||||
N = output[0].shape[0] | |||||
HW = np.prod(output[0].shape[2:]) | |||||
# N x Cout x H x W x (Cin x Kw x Kh + bias) | # N x Cout x H x W x (Cin x Kw x Kh + bias) | ||||
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias) | |||||
return np.prod(outputs[0].shape) * ( | |||||
module.in_channels // module.groups * np.prod(module.kernel_size) + bias | |||||
) | |||||
@_register_modules(m.ConvTranspose2d) | |||||
def count_deconvNd(module, input, output): | |||||
return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size) | |||||
@register_flops(m.Linear) | |||||
def flops_linear(module: m.Linear, inputs, outputs): | |||||
bias = module.out_features if module.bias is not None else 0 | |||||
return np.prod(outputs[0].shape) * module.in_features + bias | |||||
@_register_modules(m.Linear, qatm.Linear, qm.Linear) | |||||
def count_linear(module, input, output): | |||||
return np.prod(output[0].shape) * module.in_features | |||||
@register_flops(m.BatchMatMulActivation) | |||||
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs): | |||||
bias = 1 if module.bias is not None else 0 | |||||
x = inputs[0] | |||||
w = module.weight | |||||
batch_size = x.shape[0] | |||||
n, p = x.shape[1:] | |||||
_, m = w.shape[1:] | |||||
return n * (p + bias) * m * batch_size | |||||
# no need to import qat and quantized modules since they inherit from the float modules. | # no need to import qat and quantized modules since they inherit from the float modules. | ||||
hook_modules = ( | hook_modules = ( | ||||
m.Conv2d, | |||||
m.ConvTranspose2d, | |||||
m.LocalConv2d, | |||||
m.BatchNorm2d, | |||||
m.conv._ConvNd, | |||||
m.Linear, | m.Linear, | ||||
m.BatchMatMulActivation, | |||||
) | ) | ||||
@@ -106,22 +145,63 @@ def sizeof_fmt(num, suffix="B"): | |||||
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | ||||
def print_flops_stats(flops, bar_length_max=20): | |||||
flops_list = [i["flops_num"] for i in flops] | |||||
max_flops_num = max(flops_list + [0]) | |||||
# calc total flops and set flops_cum | |||||
def preprocess_receptive_field(module, inputs, outputs): | |||||
# TODO: support other dimensions | |||||
pre_rf = ( | |||||
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs), | |||||
max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs), | |||||
) | |||||
pre_stride = ( | |||||
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs), | |||||
max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs), | |||||
) | |||||
return pre_rf, pre_stride | |||||
def get_op_stats(module, inputs, outputs): | |||||
rst = { | |||||
"input_shapes": [i.shape for i in inputs], | |||||
"output_shapes": [o.shape for o in outputs], | |||||
} | |||||
valid_flag = False | |||||
for key, _dict, fallback in _iter_list: | |||||
for _type in _dict: | |||||
if isinstance(module, _type): | |||||
value = _dict[_type](module, inputs, outputs) | |||||
valid_flag = True | |||||
break | |||||
else: | |||||
if fallback is not None: | |||||
value = fallback(module, inputs, outputs) | |||||
continue | |||||
if isinstance(key, tuple): | |||||
assert isinstance(value, tuple) | |||||
for k, v in zip(key, value): | |||||
rst[k] = v | |||||
else: | |||||
rst[key] = value | |||||
if valid_flag: | |||||
return rst | |||||
else: | |||||
return None | |||||
return | |||||
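The for/else in get_op_stats is easy to misread: the else branch runs only when the inner loop finishes without a break, i.e. when no registered type matched, which is exactly when the fallback should fire. A minimal standalone sketch of the same dispatch pattern (names are illustrative, not from the patch):

def dispatch(module, impl_dict, fallback=None):
    for _type, impl in impl_dict.items():
        if isinstance(module, _type):
            return impl(module)  # first registered match wins
    # reached only when the loop above found no matching type
    return fallback(module) if fallback is not None else None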
def print_op_stats(flops, bar_length_max=20): | |||||
max_flops_num = max([i["flops_num"] for i in flops] + [0]) | |||||
total_flops_num = 0 | total_flops_num = 0 | ||||
for d in flops: | for d in flops: | ||||
total_flops_num += int(d["flops_num"]) | total_flops_num += int(d["flops_num"]) | ||||
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") | d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") | ||||
for i in flops: | |||||
f = i["flops_num"] | |||||
i["flops"] = sizeof_fmt(f, suffix="OPs") | |||||
r = i["ratio"] = f / total_flops_num | |||||
i["percentage"] = "{:.2f}%".format(r * 100) | |||||
bar_length = int(f / max_flops_num * bar_length_max) | |||||
i["bar"] = "#" * bar_length | |||||
for d in flops: | |||||
ratio = d["ratio"] = d["flops_num"] / total_flops_num | |||||
d["percentage"] = "{:.2f}%".format(ratio * 100) | |||||
bar_length = int(d["flops_num"] / max_flops_num * bar_length_max) | |||||
d["bar"] = "#" * bar_length | |||||
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs") | |||||
header = [ | header = [ | ||||
"name", | "name", | ||||
@@ -133,10 +213,13 @@ def print_flops_stats(flops, bar_length_max=20): | |||||
"percentage", | "percentage", | ||||
"bar", | "bar", | ||||
] | ] | ||||
if _receptive_field_enabled: | |||||
header.insert(4, "receptive_field") | |||||
header.insert(5, "stride") | |||||
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | ||||
total_var_size = sum( | total_var_size = sum( | ||||
sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops | |||||
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops | |||||
) | ) | ||||
flops.append( | flops.append( | ||||
dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | ||||
@@ -147,30 +230,44 @@ def print_flops_stats(flops, bar_length_max=20): | |||||
return total_flops_num | return total_flops_num | ||||
def print_params_stats(params, bar_length_max=20): | |||||
def get_param_stats(param: np.ndarray): | |||||
nbits = get_dtype_bit(param.dtype.name) | |||||
shape = param.shape | |||||
param_dim = np.prod(param.shape) | |||||
param_size = param_dim * nbits // 8 | |||||
return { | |||||
"dtype": param.dtype, | |||||
"shape": shape, | |||||
"mean": "{:.3g}".format(param.mean()), | |||||
"std": "{:.3g}".format(param.std()), | |||||
"param_dim": param_dim, | |||||
"nbits": nbits, | |||||
"size": param_size, | |||||
} | |||||
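For concreteness (a hedged example, assuming get_dtype_bit reports 32 for float32): a conv weight of shape (32, 16, 3, 3) gives param_dim = 32*16*3*3 = 4608 and size = 4608 * 32 // 8 = 18432 bytes.

import numpy as np

w = np.zeros((32, 16, 3, 3), dtype=np.float32)  # illustrative weight
stats = get_param_stats(w)
assert stats["param_dim"] == 4608
assert stats["size"] == 18432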
def print_param_stats(params, bar_length_max=20): | |||||
max_size = max([d["size"] for d in params] + [0]) | |||||
total_param_dims, total_param_size = 0, 0 | total_param_dims, total_param_size = 0, 0 | ||||
for d in params: | for d in params: | ||||
total_param_dims += int(d["param_dim"]) | total_param_dims += int(d["param_dim"]) | ||||
total_param_size += int(d["size"]) | total_param_size += int(d["size"]) | ||||
d["size"] = sizeof_fmt(d["size"]) | |||||
d["size_cum"] = sizeof_fmt(total_param_size) | d["size_cum"] = sizeof_fmt(total_param_size) | ||||
for d in params: | for d in params: | ||||
ratio = d["param_dim"] / total_param_dims | |||||
ratio = d["size"] / total_param_size | |||||
d["ratio"] = ratio | d["ratio"] = ratio | ||||
d["percentage"] = "{:.2f}%".format(ratio * 100) | d["percentage"] = "{:.2f}%".format(ratio * 100) | ||||
# construct bar | |||||
max_ratio = max([d["ratio"] for d in params]) | |||||
for d in params: | |||||
bar_length = int(d["ratio"] / max_ratio * bar_length_max) | |||||
bar_length = int(d["size"] / max_size * bar_length_max) | |||||
d["size_bar"] = "#" * bar_length | d["size_bar"] = "#" * bar_length | ||||
d["size"] = sizeof_fmt(d["size"]) | |||||
param_size = sizeof_fmt(total_param_size) | param_size = sizeof_fmt(total_param_size) | ||||
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | ||||
header = [ | header = [ | ||||
"name", | "name", | ||||
"dtype", | |||||
"shape", | "shape", | ||||
"mean", | "mean", | ||||
"std", | "std", | ||||
@@ -186,7 +283,13 @@ def print_params_stats(params, bar_length_max=20): | |||||
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | ||||
) | ) | ||||
return total_param_size | |||||
return total_param_dims, total_param_size | |||||
def print_summary(**kwargs): | |||||
data = [["item", "value"]] | |||||
data.extend(list(kwargs.items())) | |||||
logger.info("summary\n" + tabulate.tabulate(data)) | |||||
def module_stats( | def module_stats( | ||||
@@ -205,71 +308,53 @@ def module_stats( | |||||
:param log_params: whether print and record params size. | :param log_params: whether print and record params size. | ||||
:param log_flops: whether print and record op flops. | :param log_flops: whether print and record op flops. | ||||
""" | """ | ||||
disable_receptive_field() | |||||
def get_byteswidth(tensor): | |||||
if dtype.is_quantize(tensor.dtype): | |||||
return 1 | |||||
# elif dtype.is_bfloat16(tensor.dtype): | |||||
# return 2 | |||||
else: | |||||
return 4 | |||||
def module_stats_hook(module, input, output, name=""): | |||||
def module_stats_hook(module, inputs, outputs, name=""): | |||||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | class_name = str(module.__class__).split(".")[-1].split("'")[0] | ||||
flops_fun = CALC_FLOPS.get(type(module)) | |||||
if callable(flops_fun): | |||||
flops_num = flops_fun(module, input, output) | |||||
if not isinstance(output, (list, tuple)): | |||||
output = [output] | |||||
flops.append( | |||||
dict( | |||||
name=name, | |||||
class_name=class_name, | |||||
input_shapes=[i.shape for i in input], | |||||
output_shapes=[o.shape for o in output], | |||||
flops_num=flops_num, | |||||
flops_cum=0, | |||||
) | |||||
) | |||||
flops_stats = get_op_stats(module, inputs, outputs) | |||||
if flops_stats is not None: | |||||
flops_stats["name"] = name | |||||
flops_stats["class_name"] = class_name | |||||
flops.append(flops_stats) | |||||
if hasattr(module, "weight") and module.weight is not None: | if hasattr(module, "weight") and module.weight is not None: | ||||
w = module.weight | w = module.weight | ||||
value = w.numpy() | |||||
param_dim = np.prod(w.shape) | |||||
param_bytes = get_byteswidth(w) | |||||
params.append( | |||||
dict( | |||||
name=name + "-w", | |||||
shape=w.shape, | |||||
param_dim=param_dim, | |||||
bits=param_bytes * 8, | |||||
size=param_dim * param_bytes, | |||||
size_cum=0, | |||||
mean="{:.2g}".format(value.mean()), | |||||
std="{:.2g}".format(value.std()), | |||||
) | |||||
) | |||||
param_stats = get_param_stats(w.numpy()) | |||||
param_stats["name"] = name + "-w" | |||||
params.append(param_stats) | |||||
if hasattr(module, "bias") and module.bias is not None: | if hasattr(module, "bias") and module.bias is not None: | ||||
b = module.bias | b = module.bias | ||||
value = b.numpy() | |||||
param_dim = np.prod(b.shape) | |||||
param_bytes = get_byteswidth(b) | |||||
params.append( | |||||
dict( | |||||
name=name + "-b", | |||||
shape=b.shape, | |||||
param_dim=param_dim, | |||||
bits=param_bytes * 8, | |||||
size=param_dim * param_bytes, | |||||
size_cum=0, | |||||
mean="{:.2g}".format(value.mean()), | |||||
std="{:.2g}".format(value.std()), | |||||
) | |||||
) | |||||
param_stats = get_param_stats(b.numpy()) | |||||
param_stats["name"] = name + "-b" | |||||
params.append(param_stats) | |||||
@contextlib.contextmanager | |||||
def adjust_stats(module, training=False): | |||||
"""Adjust module to training/eval mode temporarily. | |||||
Args: | |||||
module (M.Module): used module. | |||||
training (bool): training mode. True for train mode, False for eval mode.
""" | |||||
def recursive_backup_stats(module, mode): | |||||
for m in module.modules(): | |||||
# save prev status to _prev_training | |||||
m._prev_training = m.training | |||||
m.train(mode, recursive=False) | |||||
def recursive_recover_stats(module): | |||||
for m in module.modules(): | |||||
# recover prev status and delete attribute | |||||
m.training = m._prev_training | |||||
delattr(m, "_prev_training") | |||||
recursive_backup_stats(module, mode=training)
try:
    yield module
finally:
    # restore the saved flags even if the forward pass raises
    recursive_recover_stats(module)
# multiple inputs to the network | # multiple inputs to the network | ||||
if not isinstance(input_size[0], tuple): | if not isinstance(input_size[0], tuple): | ||||
@@ -286,15 +371,28 @@ def module_stats( | |||||
) | ) | ||||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] | inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] | ||||
model.eval() | |||||
model(*inputs) | |||||
with adjust_stats(model, training=False) as model: | |||||
model(*inputs) | |||||
for h in hooks: | for h in hooks: | ||||
h.remove() | h.remove() | ||||
total_flops, total_params = 0, 0 | |||||
extra_info = { | |||||
"#params": len(params), | |||||
} | |||||
total_flops, total_param_dims, total_param_size = 0, 0, 0 | |||||
if log_params: | if log_params: | ||||
total_params = print_params_stats(params, bar_length_max) | |||||
total_param_dims, total_param_size = print_param_stats(params, bar_length_max) | |||||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) | |||||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
if log_flops: | if log_flops: | ||||
total_flops = print_flops_stats(flops, bar_length_max) | |||||
total_flops = print_op_stats(flops, bar_length_max) | |||||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
if log_params and log_flops: | |||||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||||
total_flops / total_param_size | |||||
) | |||||
print_summary(**extra_info) | |||||
return total_params, total_flops | |||||
return total_param_size, total_flops |
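A minimal call sketch for the updated entry point (the model and input size are illustrative; the module path is inferred from the relative imports elsewhere in this patch, and keyword names follow the hunks above):

from megengine.utils.module_stats import module_stats

# logs the param/flops tables plus the new summary, and now returns the
# total parameter size in bytes together with the total flops
total_param_size, total_flops = module_stats(
    model, input_size=(1, 3, 224, 224), log_params=True, log_flops=True
)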
@@ -11,12 +11,14 @@ import fnmatch | |||||
import itertools | import itertools | ||||
import re | import re | ||||
from collections import OrderedDict | from collections import OrderedDict | ||||
from typing import Dict, List | |||||
from typing import Dict, List, Sequence | |||||
import numpy as np | import numpy as np | ||||
from ..core._imperative_rt import ComputingGraph | from ..core._imperative_rt import ComputingGraph | ||||
from ..core._imperative_rt.core2 import SymbolVar | |||||
from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
from ..logger import get_logger | |||||
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | ||||
from .network_node import ( | from .network_node import ( | ||||
Host2DeviceCopy, | Host2DeviceCopy, | ||||
@@ -27,6 +29,8 @@ from .network_node import ( | |||||
str_to_mge_class, | str_to_mge_class, | ||||
) | ) | ||||
logger = get_logger(__name__) | |||||
class Network: | class Network: | ||||
def __init__(self): | def __init__(self): | ||||
@@ -60,12 +64,12 @@ class Network: | |||||
) | ) | ||||
outputs = [new_outputs[i] for i in outspec] | outputs = [new_outputs[i] for i in outspec] | ||||
self._orig_outputs = outputs | self._orig_outputs = outputs | ||||
self.add_dep_oprs(*outputs) | |||||
for x in self._orig_outputs: | |||||
self.output_vars.append(self._get_var(x)) | |||||
self.add_dep_oprs() | |||||
for x in self._orig_inputs: | for x in self._orig_inputs: | ||||
self.input_vars.append(self._get_var(x)) | self.input_vars.append(self._get_var(x)) | ||||
for x in self._orig_outputs: | |||||
self.output_vars.append(self._get_var(x)) | |||||
self.graph = self._orig_outputs[0].graph | self.graph = self._orig_outputs[0].graph | ||||
return self | return self | ||||
@@ -83,6 +87,58 @@ class Network: | |||||
for o in opr.outputs: | for o in opr.outputs: | ||||
self.all_vars_map[o.var.id] = o | self.all_vars_map[o.var.id] = o | ||||
def optimize_for_inference(self, dest_vars, **kwargs): | |||||
r""" | |||||
Applies optimize_for_inference pass for operator graph. | |||||
:param dest_vars: list of output vars in the operator graph | |||||
:Keyword Arguments: | |||||
* enable_io16xc32 -- | |||||
whether to use float16 for I/O between oprs and use | |||||
float32 as internal computation precision. Note the output var would be | |||||
changed to float16. | |||||
* enable_ioc16 -- | |||||
whether to use float16 for both I/O and computation | |||||
precision. | |||||
* enable_hwcd4 -- | |||||
whether to use NHWCD4 data layout. This is faster on some | |||||
OpenCL backends.
* enable_nchw88 -- | |||||
whether to use NCHW88 data layout, currently | |||||
used in X86 AVX backend. | |||||
* enable_nchw44 -- | |||||
whether to use NCHW44 data layout, currently | |||||
used in arm backend. | |||||
* enable_nchw44_dot -- | |||||
whether to use NCHW44_dot data layout, currently | |||||
used in armv8.2+dotprod backend. | |||||
* enable_nchw4 -- | |||||
whether to use NCHW4 data layout, currently | |||||
used in nvidia backend(based on cudnn). | |||||
* enable_nchw32 -- | |||||
whether to use NCHW32 data layout, currently | |||||
used in nvidia backend with tensorcore(based on cudnn). | |||||
* enable_chwn4 -- | |||||
whether to use CHWN4 data layout, currently | |||||
used in nvidia backend with tensorcore. | |||||
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr. | |||||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend (this optimization pass will
cause a precision mismatch between the outputs of training and
inference)
""" | |||||
if not isinstance(dest_vars, Sequence): | |||||
dest_vars = [dest_vars] | |||||
dest_vars = list(G.VarNode(var.var) for var in dest_vars) | |||||
new_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||||
return list(self._get_var(var) for var in new_vars) | |||||
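A usage sketch for the new method (how net was constructed is assumed; the flags mirror the docstring above):

# `net` is an already-built Network; specialize its outputs for inference
# with float16 I/O and conv+bias+nonlinearity fusion
new_outputs = net.optimize_for_inference(
    net.output_vars, enable_io16xc32=True, enable_fuse_conv_bias_nonlinearity=True
)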
def dump( | def dump( | ||||
self, | self, | ||||
file, | file, | ||||
@@ -122,47 +178,22 @@ class Network: | |||||
:Keyword Arguments: | :Keyword Arguments: | ||||
* enable_io16xc32 -- | |||||
whether to use float16 for I/O between oprs and use | |||||
float32 as internal computation precision. Note the output var would be | |||||
changed to float16. | |||||
* enable_ioc16 -- | |||||
whether to use float16 for both I/O and computation | |||||
precision. | |||||
* enable_hwcd4 -- | |||||
whether to use NHWCD4 data layout. This is faster on some | |||||
OpenCL backend. | |||||
* enable_nchw88 -- | |||||
whether to use NCHW88 data layout, currently | |||||
used in X86 AVX backend. | |||||
* enable_nchw44 -- | |||||
whether to use NCHW44 data layout, currently | |||||
used in arm backend. | |||||
* enable_nchw44_dot -- | |||||
whether to use NCHW44_dot data layout, currently | |||||
used in armv8.2+dotprod backend. | |||||
* enable_nchw4 -- | |||||
whether to use NCHW4 data layout, currently | |||||
used in nvidia backend(based on cudnn). | |||||
* enable_nchw32 -- | |||||
whether to use NCHW32 data layout, currently | |||||
used in nvidia backend with tensorcore(based on cudnn). | |||||
* enable_chwn4 -- | |||||
whether to use CHWN4 data layout, currently | |||||
used in nvidia backend with tensorcore. | |||||
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||||
into one opr. | |||||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||||
input for inference on nvidia backend(this optimization pass will | |||||
result in mismatch of the precision of output of training and | |||||
inference) | |||||
See also :py:meth:`optimize_for_inference`. | |||||
""" | """ | ||||
self._compile() | self._compile() | ||||
out = [G.VarNode(var.var) for var in self.output_vars] | out = [G.VarNode(var.var) for var in self.output_vars] | ||||
if kwargs.pop("arg_names", False): | |||||
logger.warning( | |||||
'"arg_names" is not supported in Network.dump, rename input vars directly' | |||||
) | |||||
if kwargs.pop("output_names", False): | |||||
logger.warning( | |||||
'"output_names" is not supported in Network.dump, rename output vars directly' | |||||
) | |||||
if optimize_for_inference: | if optimize_for_inference: | ||||
out = G.optimize_for_inference(out, **kwargs) | out = G.optimize_for_inference(out, **kwargs) | ||||
@@ -197,6 +228,8 @@ class Network: | |||||
def add_output(self, *vars: VarNode): | def add_output(self, *vars: VarNode): | ||||
"""Adds vars into the network output node list | """Adds vars into the network output node list | ||||
""" | """ | ||||
if not all([var.owner for var in vars]): | |||||
self.add_dep_oprs(*vars) | |||||
for var in vars: | for var in vars: | ||||
if var not in self.output_vars: | if var not in self.output_vars: | ||||
self.output_vars.append(var) | self.output_vars.append(var) | ||||
@@ -209,21 +242,25 @@ class Network: | |||||
self.output_vars.remove(var) | self.output_vars.remove(var) | ||||
def add_dep_oprs(self, *vars): | def add_dep_oprs(self, *vars): | ||||
"""Adds dependent opnodes and varnodes of vars into network | |||||
""" | |||||
oprs = get_oprs_seq(vars, False, False) | |||||
for mge_opr in oprs: | |||||
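"""Adds dependent opnodes and varnodes of vars into network; traverses from
the given vars (all output vars when none are given) back through owner oprs.
"""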
if len(vars) == 0: | |||||
vars = self.output_vars | |||||
q = list(vars) | |||||
while len(q) > 0: | |||||
cur = q.pop(0) | |||||
if cur.owner is not None: | |||||
continue | |||||
if cur.name is None: | |||||
cur.name = cur.var.name | |||||
self.all_vars_map[cur.var.id] = cur | |||||
mge_opr = cur.var.owner | |||||
if get_opr_type(mge_opr) == "Host2DeviceCopy": | if get_opr_type(mge_opr) == "Host2DeviceCopy": | ||||
self._orig_inputs.extend(mge_opr.outputs) | self._orig_inputs.extend(mge_opr.outputs) | ||||
opr = self._add_opr(mge_opr) | |||||
if opr is not None: | |||||
for x in mge_opr.inputs: | |||||
opr.add_inp_var(self._get_var(x)) | |||||
# set out var | |||||
for x in mge_opr.outputs: | |||||
opr.add_out_var(self._get_var(x)) | |||||
return [self.all_vars_map[var.id] for var in vars] | |||||
cur.owner = self._add_opr(mge_opr) | |||||
if cur.owner is None: | |||||
cur.owner = self.all_oprs_map[mge_opr.id] | |||||
continue | |||||
q.extend(cur.owner.inputs) | |||||
return list(vars) | |||||
def modify_opr_names(self, modifier): | def modify_opr_names(self, modifier): | ||||
"""Modifies names of operators **inplace**; useful for merging loaded | """Modifies names of operators **inplace**; useful for merging loaded | ||||
@@ -275,6 +312,9 @@ class Network: | |||||
Replaces vars in the graph. | Replaces vars in the graph. | ||||
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | ||||
""" | """ | ||||
if not all([var.owner for var in repl_dict.values()]):
self.add_dep_oprs(*list(repl_dict.values()))
for var in self.all_vars: | for var in self.all_vars: | ||||
if var in repl_dict: | if var in repl_dict: | ||||
repl_var = repl_dict[var] | repl_var = repl_dict[var] | ||||
@@ -282,6 +322,7 @@ class Network: | |||||
idx = owner.outputs.index(repl_var) | idx = owner.outputs.index(repl_var) | ||||
owner.outputs[idx] = var | owner.outputs[idx] = var | ||||
var.__dict__.update(repl_var.__dict__) | var.__dict__.update(repl_var.__dict__) | ||||
var.var = repl_var.var | |||||
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | ||||
""" | """ | ||||
@@ -297,6 +338,7 @@ class Network: | |||||
for ind, var in enumerate(opr.outputs): | for ind, var in enumerate(opr.outputs): | ||||
var.owner = repl_dict[opr] | var.owner = repl_dict[opr] | ||||
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | ||||
var.var = repl_dict[opr].outputs[ind].var | |||||
def get_opr_by_type(self, oprcls, unique=True): | def get_opr_by_type(self, oprcls, unique=True): | ||||
assert issubclass(oprcls, OpNode) | assert issubclass(oprcls, OpNode) | ||||
@@ -381,11 +423,16 @@ class Network: | |||||
return self.opr_filter.as_dict() | return self.opr_filter.as_dict() | ||||
# used for loading and building graph | # used for loading and building graph | ||||
def _add_opr(self, x): | |||||
def _add_opr(self, opr): | |||||
# TODO: use megbrain C++ RTTI to replace type string | # TODO: use megbrain C++ RTTI to replace type string | ||||
if x.id not in self.all_oprs_map: | |||||
self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||||
return self.all_oprs_map[x.id] | |||||
if opr.id not in self.all_oprs_map: | |||||
opnode = str_to_mge_class(get_opr_type(opr)).load(opr) | |||||
self.all_oprs_map[opr.id] = opnode | |||||
for var in opr.inputs: | |||||
opnode.add_inp_var(self._get_var(var)) | |||||
for var in opr.outputs: | |||||
opnode.add_out_var(self._get_var(var)) | |||||
return opnode | |||||
else: | else: | ||||
return None | return None | ||||
@@ -397,7 +444,7 @@ class Network: | |||||
def _get_var(self, x): | def _get_var(self, x): | ||||
# auto convert to VarNode of Network | # auto convert to VarNode of Network | ||||
if x.id not in self.all_vars_map: | |||||
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: | |||||
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | ||||
return self.all_vars_map[x.id] | return self.all_vars_map[x.id] | ||||
@@ -652,7 +699,7 @@ class NodeFilterHasInput(NodeFilter): | |||||
assert isinstance( | assert isinstance( | ||||
i, OpNode | i, OpNode | ||||
), "has_input() must be used with OpNode; " "got {!r}".format(i) | ), "has_input() must be used with OpNode; " "got {!r}".format(i) | ||||
if self.var in i.inputs: | |||||
if any(self.var is _ for _ in i.inputs): | |||||
yield i | yield i | ||||
@@ -663,6 +710,7 @@ class NodeFilterName(NodeFilter): | |||||
def __init__(self, node_iter, pattern, ignorecase): | def __init__(self, node_iter, pattern, ignorecase): | ||||
super().__init__(node_iter) | super().__init__(node_iter) | ||||
self.pattern = pattern | |||||
self._re = self.make_re(pattern, ignorecase) | self._re = self.make_re(pattern, ignorecase) | ||||
@classmethod | @classmethod | ||||
@@ -676,5 +724,5 @@ class NodeFilterName(NodeFilter): | |||||
def __iter__(self): | def __iter__(self): | ||||
for i in self._iter: | for i in self._iter: | ||||
if self._re.match(i.name): | |||||
if self.pattern == i.name or self._re.match(i.name): | |||||
yield i | yield i |
@@ -6,27 +6,41 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import abc | |||||
import json | import json | ||||
import sys | import sys | ||||
from typing import Callable | |||||
from typing import Callable, Sequence | |||||
import numpy as np | import numpy as np | ||||
from ..core import _imperative_rt as rt | from ..core import _imperative_rt as rt | ||||
from ..core._imperative_rt.core2 import SymbolVar | |||||
from ..core._wrap import Device | from ..core._wrap import Device | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.tensor.megbrain_graph import InputNode | |||||
from ..core.tensor.array_method import ArrayMethodMixin | |||||
from ..core.tensor.indexing import getitem as _getitem | |||||
from ..core.tensor.indexing import setitem as _setitem | |||||
from ..core.tensor.megbrain_graph import InputNode, OutputNode | |||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .comp_graph_tools import replace_vars | from .comp_graph_tools import replace_vars | ||||
from .module_stats import ( | |||||
preprocess_receptive_field, | |||||
register_flops, | |||||
register_receptive_field, | |||||
) | |||||
class NetworkNode: | class NetworkNode: | ||||
pass | pass | ||||
class VarNode(NetworkNode): | |||||
def __init__(self, owner_opr=None, name=None): | |||||
self.var = None | |||||
class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): | |||||
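# resolves the metaclass conflict between the pybind11-defined SymbolVar and ArrayMethodMixin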
pass | |||||
class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||||
def __init__(self, var=None, *, owner_opr=None, name=None): | |||||
SymbolVar.__init__(self, var) | |||||
self.owner = owner_opr | self.owner = owner_opr | ||||
self.name = name | self.name = name | ||||
self.id = id(self) | self.id = id(self) | ||||
@@ -53,6 +67,40 @@ class VarNode(NetworkNode): | |||||
def dtype(self): | def dtype(self): | ||||
return self.var.dtype if self.var else None | return self.var.dtype if self.var else None | ||||
def __bool__(self): | |||||
return False | |||||
__index__ = None | |||||
__int__ = None | |||||
__float__ = None | |||||
__complex__ = None | |||||
def __hash__(self): | |||||
return id(self) | |||||
@property | |||||
def _tuple_shape(self): | |||||
return self.var.shape | |||||
def numpy(self): | |||||
o = OutputNode(self.var) | |||||
self.graph.compile(o.outputs).execute() | |||||
return o.get_value().numpy() | |||||
def __getitem__(self, index): | |||||
return _getitem(self, index) | |||||
def __setitem__(self, index, value): | |||||
if index is not Ellipsis: | |||||
value = _setitem(self, index, value) | |||||
if self.owner is not None: | |||||
idx = self.owner.outputs.index(self) | |||||
self.owner.outputs[idx] = VarNode( | |||||
self.var, owner_opr=self.owner, name=self.var.name | |||||
) | |||||
self.var = value.var | |||||
self.owner = None | |||||
def set_owner_opr(self, owner_opr): | def set_owner_opr(self, owner_opr): | ||||
self.owner = owner_opr | self.owner = owner_opr | ||||
@@ -130,7 +178,7 @@ class Host2DeviceCopy(OpNode): | |||||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | ||||
self._opr = outputs.owner | self._opr = outputs.owner | ||||
if len(self.outputs) == 0: | if len(self.outputs) == 0: | ||||
self.outputs.append(VarNode(self, self.name)) | |||||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
self.outputs[0].var = outputs | self.outputs[0].var = outputs | ||||
assert self.outputs[0].owner is self | assert self.outputs[0].owner is self | ||||
@@ -168,8 +216,8 @@ class ImmutableTensor(OpNode): | |||||
def set_value(self, data, device=None): | def set_value(self, data, device=None): | ||||
assert self.graph is not None | assert self.graph is not None | ||||
cn = device if device else self.device | cn = device if device else self.device | ||||
assert isinstance(data, (int, float, np.ndarray)) | |||||
if isinstance(data, (int, float)): | |||||
assert isinstance(data, (int, float, Sequence, np.ndarray)) | |||||
if not isinstance(data, np.ndarray): | |||||
data = np.array(data) | data = np.array(data) | ||||
if data.dtype == np.float64: | if data.dtype == np.float64: | ||||
data = data.astype(np.float32) | data = data.astype(np.float32) | ||||
@@ -177,7 +225,7 @@ class ImmutableTensor(OpNode): | |||||
data = data.astype(np.int32) | data = data.astype(np.int32) | ||||
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | ||||
if len(self.outputs) == 0: | if len(self.outputs) == 0: | ||||
self.outputs.append(VarNode(self, self.name)) | |||||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
self.outputs[0].var = varnode | self.outputs[0].var = varnode | ||||
self._opr = varnode.owner | self._opr = varnode.owner | ||||
@@ -225,8 +273,21 @@ class Elemwise(OpNode): | |||||
type = "Elemwise" | type = "Elemwise" | ||||
opdef = builtin.Elemwise | opdef = builtin.Elemwise | ||||
def calc_flops(self): | |||||
return np.prod(self.outputs[0].shape) | |||||
class ElemwiseMultiType(OpNode): | |||||
type = "ElemwiseMultiType" | |||||
opdef = builtin.ElemwiseMultiType | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(ElemwiseMultiType, cls).load(opr) | |||||
obj.params["dtype"] = opr.outputs[0].dtype | |||||
return obj | |||||
@register_flops(Elemwise, ElemwiseMultiType) | |||||
def flops_elemwise(opnode: Elemwise, inputs, outputs): | |||||
return np.prod(outputs[0].shape) | |||||
class Reduce(OpNode): | class Reduce(OpNode): | ||||
@@ -255,20 +316,24 @@ class MatrixMul(OpNode): | |||||
type = "MatrixMul" | type = "MatrixMul" | ||||
opdef = builtin.MatrixMul | opdef = builtin.MatrixMul | ||||
def calc_flops(self): | |||||
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2 | |||||
mid_shape = self.inputs[0].shape[1] | |||||
return np.prod(self.outputs[0].shape) * mid_shape | |||||
@register_flops(MatrixMul) | |||||
def flops_matmul(opnode: MatrixMul, inputs, outputs): | |||||
assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2 | |||||
mid_shape = inputs[0].shape[1] | |||||
return np.prod(outputs[0].shape) * mid_shape | |||||
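Equivalently, for A of shape (m, k) times B of shape (k, n) the count is m * n * k multiply-accumulates; a tiny hand check with illustrative numbers:

import numpy as np

out_shape, mid_shape = (4, 6), 5  # (m, n) output, k contracted dim
assert np.prod(out_shape) * mid_shape == 120  # 4 * 6 * 5 MACs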
class BatchedMatrixMul(OpNode): | class BatchedMatrixMul(OpNode): | ||||
type = "BatchedMatmul" | type = "BatchedMatmul" | ||||
opdef = builtin.BatchedMatrixMul | opdef = builtin.BatchedMatrixMul | ||||
def calc_flops(self): | |||||
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3 | |||||
mid_shape = self.inputs[0].shape[2] | |||||
return np.prod(self.outputs[0].shape) * mid_shape | |||||
@register_flops(BatchedMatrixMul) | |||||
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs): | |||||
assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3 | |||||
mid_shape = inputs[0].shape[2] | |||||
return np.prod(outputs[0].shape) * mid_shape | |||||
class Dot(OpNode): | class Dot(OpNode): | ||||
@@ -285,18 +350,6 @@ class ConvolutionForward(OpNode): | |||||
type = "Convolution" | type = "Convolution" | ||||
opdef = builtin.Convolution | opdef = builtin.Convolution | ||||
def calc_flops(self): | |||||
param_W_shape = self.inputs[1].shape | |||||
kh = param_W_shape[-2] | |||||
kw = param_W_shape[-1] | |||||
if len(param_W_shape) == 5: | |||||
num_input = param_W_shape[2] | |||||
else: | |||||
num_input = param_W_shape[1] | |||||
NCHW = np.prod(self.outputs[0].shape) | |||||
# N x Cout x H x W x (Cin x Kw x Kh) | |||||
return NCHW * (num_input * kw * kh) | |||||
class ConvolutionBackwardData(OpNode): | class ConvolutionBackwardData(OpNode): | ||||
type = "ConvTranspose" | type = "ConvTranspose" | ||||
@@ -343,17 +396,41 @@ class ConvBiasForward(OpNode): | |||||
obj.params["dtype"] = opr.outputs[0].dtype | obj.params["dtype"] = opr.outputs[0].dtype | ||||
return obj | return obj | ||||
def calc_flops(self): | |||||
param_W_shape = self.inputs[1].shape | |||||
kh = param_W_shape[-2] | |||||
kw = param_W_shape[-1] | |||||
if len(param_W_shape) == 5: | |||||
num_input = param_W_shape[2] | |||||
else: | |||||
num_input = param_W_shape[1] | |||||
NCHW = np.prod(self.outputs[0].shape) | |||||
# N x Cout x H x W x (Cin x Kw x Kh + bias) | |||||
return NCHW * (num_input * kw * kh + 1) | |||||
@register_flops( | |||||
ConvolutionForward, ConvBiasForward, | |||||
) | |||||
def flops_conv(opnode: ConvolutionForward, inputs, outputs): | |||||
param_W_shape = inputs[1].shape | |||||
kh = param_W_shape[-2] | |||||
kw = param_W_shape[-1] | |||||
if len(param_W_shape) == 5: | |||||
num_input = param_W_shape[2] | |||||
else: | |||||
num_input = param_W_shape[1] | |||||
NCHW = np.prod(outputs[0].shape) | |||||
bias = 1 if isinstance(opnode, ConvBiasForward) else 0 | |||||
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + bias) | |||||
@register_receptive_field(ConvolutionForward, ConvBiasForward) | |||||
def receptive_field(opnode: ConvolutionForward, inputs, outputs): | |||||
pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs) | |||||
param_W_shape = inputs[1].shape | |||||
kh = param_W_shape[-2] | |||||
kw = param_W_shape[-1] | |||||
rf = ( | |||||
kh * pre_stride[0] + pre_rf[0] - pre_stride[0], | |||||
kw * pre_stride[1] + pre_rf[1] - pre_stride[1], | |||||
) | |||||
stride = ( | |||||
opnode.params["stride_h"] * pre_stride[0], | |||||
opnode.params["stride_w"] * pre_stride[1], | |||||
) | |||||
opnode._rf = rf | |||||
opnode._stride = stride | |||||
return rf, stride | |||||
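To make the recurrence concrete (a hedged, hand-computed example): stacking two 3x3 convolutions with stride 2 gives a receptive field of 3 and then 7, with cumulative stride 2 and then 4.

def rf_step(k, s, pre_rf, pre_stride):
    # same recurrence as above, one spatial dimension
    return k * pre_stride + pre_rf - pre_stride, s * pre_stride

rf, stride = rf_step(3, 2, pre_rf=1, pre_stride=1)  # first conv  -> (3, 2)
rf, stride = rf_step(3, 2, rf, stride)              # second conv -> (7, 4)
assert (rf, stride) == (7, 4)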
class BatchConvBiasForward(OpNode): | class BatchConvBiasForward(OpNode): | ||||
@@ -652,20 +729,6 @@ class AssertEqual(OpNode): | |||||
opdef = builtin.AssertEqual | opdef = builtin.AssertEqual | ||||
class ElemwiseMultiType(OpNode): | |||||
type = "ElemwiseMultiType" | |||||
opdef = builtin.ElemwiseMultiType | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(ElemwiseMultiType, cls).load(opr) | |||||
obj.params["dtype"] = opr.outputs[0].dtype | |||||
return obj | |||||
def calc_flops(self): | |||||
return np.prod(self.outputs[0].shape) | |||||
class CvtColorForward(OpNode): | class CvtColorForward(OpNode): | ||||
type = "CvtColor" | type = "CvtColor" | ||||
opdef = builtin.CvtColor | opdef = builtin.CvtColor |
@@ -266,7 +266,7 @@ void init_graph_rt(py::module m) { | |||||
{"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, | {"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, | ||||
{"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, | {"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, | ||||
{"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, | {"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, | ||||
{"OPTMIZED", [&]() { stg = _AlgoStrategy::OPTMIZED; }}, | |||||
{"OPTIMIZED", [&]() { stg = _AlgoStrategy::OPTIMIZED; }}, | |||||
}; | }; | ||||
auto it = m.find(strategy); | auto it = m.find(strategy); | ||||
mgb_assert(it != m.end(), "Invalid strategy string!"); | mgb_assert(it != m.end(), "Invalid strategy string!"); | ||||
@@ -87,9 +87,13 @@ struct pyobj_convert_generic { | |||||
} | } | ||||
}; | }; | ||||
template<typename T, typename SFINAE=void> | |||||
struct EnumTrait; | |||||
template <typename T> | template <typename T> | ||||
struct EnumTrait { | |||||
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> { | |||||
static constexpr bool is_bit_combined = false; | static constexpr bool is_bit_combined = false; | ||||
static constexpr std::underlying_type_t<T> max = 0; | |||||
}; | }; | ||||
template <typename T> | template <typename T> | ||||
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper { | |||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||||
return obj; | |||||
} | |||||
static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||||
int input = 1; | |||||
if (PyArg_ParseTuple(args, "|i", &input)){ | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||||
static_cast<T>(input); | |||||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) { | |||||
if (!PyTuple_Size(args)) { | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T(); | |||||
return obj; | |||||
} | |||||
else { | |||||
PyObject* input; | |||||
if (!PyArg_ParseTuple(args, "|O", &input)) { | |||||
return nullptr; | |||||
} | |||||
T value; | |||||
try { | |||||
value = pyobj_convert_generic<T>::from(input); | |||||
} CATCH_ALL(nullptr); | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value; | |||||
return obj; | |||||
} | } | ||||
return 0; | |||||
} | } | ||||
static PyObject* py_repr(PyObject* self) { | static PyObject* py_repr(PyObject* self) { | ||||
return pyobj_convert_generic<std::string>::to( | return pyobj_convert_generic<std::string>::to( | ||||
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T, | |||||
static T from(PyObject* obj) { | static T from(PyObject* obj) { | ||||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | if (PyObject_TypeCheck(obj, &Wrapper::type)) { | ||||
return reinterpret_cast<Wrapper*>(obj)->value; | return reinterpret_cast<Wrapper*>(obj)->value; | ||||
} else if(PyLong_Check(obj)) { | |||||
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj); | |||||
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError, | |||||
"out of range, cannot convert %zu to %s", | |||||
static_cast<uint32_t>(value), Wrapper::name); | |||||
return static_cast<T>(value); | |||||
} | } | ||||
// try as string | // try as string | ||||
// TODO: type check | // TODO: type check | ||||
@@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
if (ctx.op->same_type<BackwardGraph>()) { | if (ctx.op->same_type<BackwardGraph>()) { | ||||
ctx.backward = true; | ctx.backward = true; | ||||
} | } | ||||
if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||||
SmallVector<cg::VarNode*> vinputs(nargs); | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||||
} | |||||
auto op = ctx.op.get(); | |||||
return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||||
} | |||||
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | |||||
SmallVector<cg::VarNode*> vinputs(nargs); | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node; | |||||
} | |||||
auto op = ctx.op.get(); | |||||
auto rst = OpDef::apply_on_var_node(*op, vinputs); | |||||
auto ret = pybind11::tuple(rst.size()); | |||||
auto typeobj = py::handle(args[0]).get_type(); | |||||
for (size_t i = 0; i<rst.size(); ++i) { | |||||
ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic)); | |||||
} | |||||
return ret.release().ptr(); | |||||
} | |||||
for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | ||||
@@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||||
continue; | continue; | ||||
} | } | ||||
if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||||
auto var = py::handle(handle).cast<cg::VarNode *>(); | |||||
mgb::DType type = var->dtype(); | |||||
if (py::isinstance<PySymbolVar>(py::handle(handle))){ | |||||
auto var = py::handle(handle).cast<PySymbolVar*>(); | |||||
mgb::DType type = var->m_node->dtype(); | |||||
auto && descr = npy::dtype_mgb2np_descr(type); | auto && descr = npy::dtype_mgb2np_descr(type); | ||||
Py_INCREF(descr.get()); | Py_INCREF(descr.get()); | ||||
tensors.emplace_back(descr.get()); | tensors.emplace_back(descr.get()); | ||||
@@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||||
bool valid = false; | bool valid = false; | ||||
CompNode cn; | CompNode cn; | ||||
for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; | |||||
TensorWrapper* tw = TensorWrapper::try_cast(handle); | TensorWrapper* tw = TensorWrapper::try_cast(handle); | ||||
bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||||
if (tw || is_var) { | |||||
bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle)); | |||||
if (tw || is_symvar) { | |||||
if (!valid) { | if (!valid) { | ||||
cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||||
cn = tw ? tw->m_tensor->comp_node() | |||||
: py::handle(handle) | |||||
.cast<PySymbolVar*>() | |||||
->m_node->comp_node(); | |||||
valid = true; | valid = true; | ||||
} else { | } else { | ||||
CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||||
CompNode cn1 = tw ? tw->m_tensor->comp_node() | |||||
: py::handle(handle) | |||||
.cast<PySymbolVar*>() | |||||
->m_node->comp_node(); | |||||
if (cn1 != cn) { | if (cn1 != cn) { | ||||
throw py::value_error(ssprintf("ambiguous device: %s vs %s", | throw py::value_error(ssprintf("ambiguous device: %s vs %s", | ||||
cn.to_string().c_str(), cn1.to_string().c_str())); | |||||
cn.to_string().c_str(), | |||||
cn1.to_string().c_str())); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -849,6 +861,32 @@ void init_tensor(py::module m) { | |||||
.def("__call__", &TensorWeakRef::operator()) | .def("__call__", &TensorWeakRef::operator()) | ||||
.def("_use_cnt", &TensorWeakRef::_use_cnt); | .def("_use_cnt", &TensorWeakRef::_use_cnt); | ||||
py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") | |||||
.def_property_readonly( | |||||
"dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) | |||||
.def_property("var", [](PySymbolVar* v) { return v->m_node; }, | |||||
[](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) | |||||
.def_property_readonly( | |||||
"device", | |||||
[](PySymbolVar* v) { return v->m_node->comp_node(); }) | |||||
.def_property_readonly( | |||||
"graph", | |||||
[](PySymbolVar* v) { return v->m_node->owner_graph(); }) | |||||
.def_property_readonly( | |||||
"shape", | |||||
[](PySymbolVar* v) -> const TensorShape* { | |||||
auto&& mgr = v->m_node->owner_graph() | |||||
->static_infer_manager(); | |||||
return mgr.infer_shape_fallible(v->m_node); | |||||
}) | |||||
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||||
.def("_setscalar", | |||||
[](PySymbolVar* v) { return v->is_scalar = true; }) | |||||
.def(py::init([](cg::VarNode* node) { | |||||
return std::make_shared<PySymbolVar>(node); | |||||
}), | |||||
py::arg() = nullptr); | |||||
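On the Python side the binding can be exercised roughly like this (a sketch; `var` is assumed to be an existing megbrain VarNode):

from megengine.core._imperative_rt.core2 import SymbolVar

sv = SymbolVar(var)                   # wrap a cg::VarNode
print(sv.dtype, sv.device, sv.shape)  # metadata resolved via the owner graph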
static PyMethodDef method_defs[] = { | static PyMethodDef method_defs[] = { | ||||
MGE_PY_INTERFACE(apply, py_apply), | MGE_PY_INTERFACE(apply, py_apply), | ||||
MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | ||||
@@ -181,6 +181,12 @@ struct TensorWrapper { | |||||
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | ||||
}; | }; | ||||
struct PySymbolVar { | |||||
cg::VarNode* m_node = nullptr; | |||||
bool is_scalar = false; | |||||
PySymbolVar() = default; | |||||
PySymbolVar(VarNode *m): m_node(m){} | |||||
}; | |||||
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | ||||
@@ -2,9 +2,11 @@ import io | |||||
import numpy as np | import numpy as np | ||||
import megengine.core.tensor.megbrain_graph as G | |||||
import megengine.utils.comp_graph_tools as cgtools | import megengine.utils.comp_graph_tools as cgtools | ||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.jit import trace | from megengine.jit import trace | ||||
from megengine.utils.network_node import VarNode | |||||
def _default_compare_fn(x, y): | def _default_compare_fn(x, y): | ||||
@@ -14,8 +16,23 @@ def _default_compare_fn(x, y): | |||||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | ||||
def make_tensor(x, network=None, device=None): | |||||
if network is not None: | |||||
if isinstance(x, VarNode): | |||||
return VarNode(x.var) | |||||
return network.make_const(x, device=device) | |||||
else: | |||||
return tensor(x, device=device) | |||||
def opr_test( | def opr_test( | ||||
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||||
cases, | |||||
func, | |||||
compare_fn=_default_compare_fn, | |||||
ref_fn=None, | |||||
test_trace=True, | |||||
network=None, | |||||
**kwargs | |||||
): | ): | ||||
""" | """ | ||||
:param cases: a list of dict-typed cases; the list length should be 2 for dynamic shape tests. | :param cases: a list of dict-typed cases; the list length should be 2 for dynamic shape tests. | ||||
@@ -44,7 +61,7 @@ def opr_test( | |||||
if not isinstance(results, (tuple, list)): | if not isinstance(results, (tuple, list)): | ||||
results = (results,) | results = (results,) | ||||
for r, e in zip(results, expected): | for r, e in zip(results, expected): | ||||
if not isinstance(r, tensor): | |||||
if not isinstance(r, (tensor, VarNode)): | |||||
r = tensor(r) | r = tensor(r) | ||||
compare_fn(r, e) | compare_fn(r, e) | ||||
@@ -72,9 +89,9 @@ def opr_test( | |||||
raise ValueError("the input func should be callable") | raise ValueError("the input func should be callable") | ||||
inp, outp = get_param(cases, 0) | inp, outp = get_param(cases, 0) | ||||
inp_tensor = [tensor(inpi) for inpi in inp] | |||||
inp_tensor = [make_tensor(inpi, network) for inpi in inp] | |||||
if test_trace: | |||||
if test_trace and not network: | |||||
copied_inp = inp_tensor.copy() | copied_inp = inp_tensor.copy() | ||||
for symbolic in [False, True]: | for symbolic in [False, True]: | ||||
traced_func = trace(symbolic=symbolic)(func) | traced_func = trace(symbolic=symbolic)(func) | ||||
@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||||
) | ) | ||||
step += 1 | step += 1 | ||||
check_func(ori_params, net.parameters(), step) | check_func(ori_params, net.parameters(), step) | ||||
try_state_dict = { | |||||
"net": net.state_dict(), | |||||
"opt": opt.state_dict(), | |||||
} | |||||
def test_sgd(): | def test_sgd(): | ||||
@@ -10,12 +10,17 @@ import collections | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from utils import make_tensor | |||||
import megengine | import megengine | ||||
import megengine.core.tensor.megbrain_graph as G | |||||
import megengine.functional as F | |||||
from megengine.core._imperative_rt.core2 import apply | from megengine.core._imperative_rt.core2 import apply | ||||
from megengine.core._trace_option import use_symbolic_shape | from megengine.core._trace_option import use_symbolic_shape | ||||
from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
from megengine.tensor import Tensor | from megengine.tensor import Tensor | ||||
from megengine.utils.network import Network | |||||
from megengine.utils.network_node import VarNode | |||||
def cvt_to_shape_desc(val, inpvar, config=None): | def cvt_to_shape_desc(val, inpvar, config=None): | ||||
@@ -387,108 +392,130 @@ def test_batched_mesh_indexing(): | |||||
# high level | # high level | ||||
def get_value(x): | |||||
if isinstance(x, VarNode): | |||||
var = x.var | |||||
o = G.OutputNode(var) | |||||
graph = x.graph | |||||
graph.compile(o.outputs).execute() | |||||
return o.get_value().numpy() | |||||
else: | |||||
return x.numpy() | |||||
@pytest.mark.parametrize("test_varnode", [True, False]) | |||||
def test_advance_indexing_high_level(test_varnode): | |||||
if test_varnode: | |||||
network = Network() | |||||
else: | |||||
network = None | |||||
def test_advance_indexing_high_level(): | |||||
x = np.arange(25).reshape(5, 5).astype("int32") | x = np.arange(25).reshape(5, 5).astype("int32") | ||||
d = np.arange(15).reshape(3, 5).astype("int32") | d = np.arange(15).reshape(3, 5).astype("int32") | ||||
xx = Tensor(x) | |||||
xx = make_tensor(x, network) | |||||
np.testing.assert_equal(x[1, :], xx[1, :].numpy()) | |||||
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||||
np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy()) | |||||
np.testing.assert_equal(x[1, :], get_value(xx[1, :])) | |||||
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||||
np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) | |||||
np.testing.assert_equal(x[:, :], xx[:, :].numpy()) | |||||
np.testing.assert_equal(x[1, 1], xx[1, 1].numpy()) | |||||
np.testing.assert_equal(x[:, :], get_value(xx[:, :])) | |||||
np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) | |||||
yy = xx[(0, 4, 2), :] | yy = xx[(0, 4, 2), :] | ||||
np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy()) | |||||
np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy)) | |||||
x_ = x.copy() | x_ = x.copy() | ||||
x_[(0, 4, 2), :] = d | x_[(0, 4, 2), :] = d | ||||
xx_ = Tensor(xx) | |||||
xx_ = make_tensor(xx, network) | |||||
xx_[(0, 4, 2), :] = d | xx_[(0, 4, 2), :] = d | ||||
np.testing.assert_equal(x_, xx_.numpy()) | |||||
np.testing.assert_equal(x_, get_value(xx_)) | |||||
x = np.arange(27).reshape(3, 3, 3).astype("int32") | x = np.arange(27).reshape(3, 3, 3).astype("int32") | ||||
xx = Tensor(x) | |||||
xx = make_tensor(x, network) | |||||
np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy()) | |||||
np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy()) | |||||
np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy()) | |||||
np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy()) | |||||
np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy()) | |||||
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||||
np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy()) | |||||
np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :])) | |||||
np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1])) | |||||
np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :])) | |||||
np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1])) | |||||
np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1])) | |||||
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||||
np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2])) | |||||
x_ = x.copy() | x_ = x.copy() | ||||
x_[1, 1, 1] = -1 | x_[1, 1, 1] = -1 | ||||
xx[1, 1, 1] = -1 | xx[1, 1, 1] = -1 | ||||
np.testing.assert_equal(x_, xx.numpy()) | |||||
np.testing.assert_equal(x_, get_value(xx)) | |||||
x_[:, 1, 1] = -2 | x_[:, 1, 1] = -2 | ||||
xx[:, 1, 1] = x_[:, 1, 1] | xx[:, 1, 1] = x_[:, 1, 1] | ||||
np.testing.assert_equal(x_, xx.numpy()) | |||||
np.testing.assert_equal(x_, get_value(xx)) | |||||
x_[0:1, :, 1] = -3 | x_[0:1, :, 1] = -3 | ||||
xx[0:1, :, 1] = x_[0:1, :, 1] | xx[0:1, :, 1] = x_[0:1, :, 1] | ||||
np.testing.assert_equal(x_, xx.numpy()) | |||||
np.testing.assert_equal(x_, get_value(xx)) | |||||
x_[0:1, :, 1] = -4 | x_[0:1, :, 1] = -4 | ||||
y = Tensor(x_) | |||||
y = make_tensor(x_, network) | |||||
xx[0:1, :, 1] = y[0:1, :, 1] | xx[0:1, :, 1] = y[0:1, :, 1] | ||||
np.testing.assert_equal(y.numpy(), xx.numpy())
np.testing.assert_equal(get_value(y), get_value(xx))
x[:] = 1
xx[:] = 1
np.testing.assert_equal(x, xx.numpy())
np.testing.assert_equal(x, get_value(xx))

x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
y = np.array([1, 2])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
yy = make_tensor(y, network)
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))
x_ = x.copy()
x_[:, y[0]] = -1
xx_ = Tensor(x_)
xx_ = make_tensor(x_, network)
xx[:, yy[0]] = xx_[:, yy[0]]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))
x_[:, y] = -1
xx_ = Tensor(x_)
xx_ = make_tensor(x_, network)
xx[:, yy] = xx_[:, yy]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))

x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
y = np.array([1])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
yy = make_tensor(y, network)
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))

x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy())
np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy())
xx = make_tensor(x, network)
np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0]))
np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0]))

def test_advance_indexing_with_bool():
@pytest.mark.parametrize(
"test_varnode", [True, False],
)
def test_advance_indexing_with_bool(test_varnode):
if test_varnode:
network = Network()
else:
network = None
a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([1, 2, 3])
c = np.array([1, 2, 3])
aa = Tensor(a)
bb = Tensor(b)
cc = Tensor(c)
np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
cc = make_tensor(c, network)
np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2]))
a[b == 1, c == 2] = -1.0
aa[bb == 1, cc == 2] = -1.0
np.testing.assert_equal(a, aa.numpy())
np.testing.assert_equal(a, get_value(aa))

a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([False, True, True])
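The tests above lean on two helpers from the test-side utils module whose bodies are not part of this diff. A minimal sketch of what they presumably do, assuming Network.make_const is available and using a hypothetical eval_varnode helper:

    def make_tensor(x, network=None, device=None):
        # imperative mode: a plain Tensor
        if network is None:
            return Tensor(x, device=device)
        # graph mode: materialize x as a constant VarNode owned by the Network
        return network.make_const(x, device=device)

    def get_value(x):
        # Tensor exposes .numpy() directly; a VarNode must be evaluated first
        if isinstance(x, VarNode):
            return eval_varnode(x)  # hypothetical: compile and run the subgraph
        return x.numpy()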
@@ -11,13 +11,16 @@ import platform
import numpy as np
import pytest
from utils import opr_test
from utils import make_tensor, opr_test

import megengine.functional as F
from megengine import tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork
from megengine.utils.network import Network
from megengine.utils.network_node import VarNode

def test_eye():

@@ -38,7 +41,13 @@ def test_eye():
)

def test_concat():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def get_data_shape(length: int):
return (length, 2, 3)

@@ -50,18 +59,30 @@ def test_concat():
return F.concat([data1, data2])
cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)

def test_concat_device():
data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0")
data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
out = F.concat([data1, data2], device="cpu0")
assert str(out.device).split(":")[0] == "cpu0"

def test_stack():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1 = np.random.random((3, 2, 2)).astype("float32")
data2 = np.random.random((3, 2, 2)).astype("float32")
data3 = np.random.random((3, 2, 2)).astype("float32")

@@ -72,12 +93,20 @@ def test_stack():
def run(data1, data2):
return F.stack([data1, data2], axis=ai)
opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai))
opr_test(
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
)

def test_split():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_split(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data = np.random.random((2, 3, 4, 5)).astype(np.float32)
inp = tensor(data)
inp = make_tensor(data, network)
mge_out0 = F.split(inp, 2, axis=3)
mge_out1 = F.split(inp, [3], axis=3)

@@ -106,26 +135,42 @@ def test_split():
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"

def test_reshape():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.arange(6, dtype="float32")
xx = tensor(x)
xx = make_tensor(x, network)
y = x.reshape(1, 2, 3)
for shape in [
(1, 2, 3),
(1, -1, 3),
(1, tensor(-1), 3),
(1, make_tensor(-1, network), 3),
np.array([1, -1, 3], dtype="int32"),
tensor([1, -1, 3]),
make_tensor([1, -1, 3], network),
]:
yy = F.reshape(xx, shape)
np.testing.assert_equal(yy.numpy(), y)
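For reference, a -1 entry in the shape spec is inferred from the element count, and the loop above checks that it may be given as a Python int, a numpy value, or a (graph or imperative) tensor. A small worked case:

    xx = make_tensor(np.arange(6, dtype="float32"), None)
    yy = F.reshape(xx, (1, -1, 3))   # -1 is inferred as 2, so yy.shape == (1, 2, 3)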
def test_reshape_shape_inference():
x_shape_known = tensor([1, 2, 3, 4], dtype="float32")
x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape_shape_inference(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_shape_known = make_tensor([1, 2, 3, 4], network)
x_shape_unknown = F.broadcast_to(
make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
)
tshp_unknown = astensor1d(
(make_tensor([2], network), make_tensor([2], network)), x_shape_known
)
tshp_known = astensor1d((2, 2), x_shape_known)
tshp_known_unspec = astensor1d((2, -1), x_shape_known)

@@ -146,12 +191,18 @@ def test_reshape_shape_inference():
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
]
opr_test(cases, func, compare_fn=check_shape, test_trace=True)
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)

def test_squeeze():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_squeeze(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = tensor(x)
xx = make_tensor(x, network)
for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis)

@@ -159,9 +210,15 @@ def test_squeeze():
np.testing.assert_equal(y, yy.numpy())

def test_expand_dims():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_expand_dims(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.arange(6, dtype="float32").reshape(2, 3)
xx = tensor(x)
xx = make_tensor(x, network)
for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis)

@@ -169,11 +226,17 @@ def test_expand_dims():
np.testing.assert_equal(y, yy.numpy())

def test_elemwise_dtype_promotion():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.random.rand(2, 3).astype("float32")
y = np.random.rand(1, 3).astype("float16")
xx = tensor(x)
yy = tensor(y)
xx = make_tensor(x, network)
yy = make_tensor(y, network)
z = xx * yy
np.testing.assert_equal(z.numpy(), x * y)

@@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion():
np.testing.assert_equal(z.numpy(), x - y)

def test_linspace():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_linspace(is_varnode):
if is_varnode:
network = Network()
else:
network = None
cases = [
{"input": [1, 9, 9]},
{"input": [3, 10, 8]},

@@ -193,6 +262,7 @@ def test_linspace():
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
)
cases = [

@@ -203,20 +273,28 @@ def test_linspace():
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
)
cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
{"input": [1, make_tensor(9, network), 9]},
{"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
network=network,
)

def test_arange():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_arange(is_varnode):
if is_varnode:
network = Network()
else:
network = None
cases = [
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},

@@ -225,6 +303,7 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)
cases = [

@@ -235,6 +314,7 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)
cases = [

@@ -245,20 +325,33 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)

def test_round():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_round(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1_shape = (15,)
data2_shape = (25,)
data1 = np.random.random(data1_shape).astype(np.float32)
data2 = np.random.random(data2_shape).astype(np.float32)
cases = [{"input": data1}, {"input": data2}]
opr_test(cases, F.round, ref_fn=np.round)
opr_test(cases, F.round, ref_fn=np.round, network=network)

def test_flatten():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_flatten(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32)

@@ -273,7 +366,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn)
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
output0 = (2, 3 * 4 * 5)
output1 = (4, 5 * 6 * 7)

@@ -281,7 +374,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
output0 = (2, 3, 4 * 5)
output1 = (4, 5, 6 * 7)

@@ -289,7 +382,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
output0 = (2, 3 * 4, 5)
output1 = (4, 5 * 6, 7)

@@ -297,10 +390,23 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2)
opr_test(
cases,
F.flatten,
compare_fn=compare_fn,
start_axis=1,
end_axis=2,
network=network,
)

def test_broadcast():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast(is_varnode):
if is_varnode:
network = Network()
else:
network = None
input1_shape = (20, 30)
output1_shape = (30, 20, 30)
data1 = np.random.random(input1_shape).astype(np.float32)

@@ -309,14 +415,19 @@ def test_broadcast():
output2_shape = (20, 10, 20)
data2 = np.random.random(input2_shape).astype(np.float32)
input3_shape = (10, 10)
output3_shape = (10, 10)
data3 = np.random.random(input3_shape).astype(np.float32)

def compare_fn(x, y):
assert x.shape[0] == y
cases = [
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape},
]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):

@@ -329,35 +440,41 @@ def test_broadcast():
F.broadcast_to(x, (1, 3))
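The new (10, 10) -> (10, 10) case checks that broadcasting to an identical shape is accepted as a no-op; the RuntimeError cases exercise the usual numpy-style rules, roughly:

    x = F.ones((2, 1, 3))
    F.broadcast_to(x, (4, 2, 1, 3))  # ok: leading axes may be prepended
    F.broadcast_to(x, (2, 4, 3))     # ok: an axis of size 1 may be expanded
    F.broadcast_to(x, (1, 3))        # error: the rank can never shrink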
def test_utils_astensor1d():
reference = tensor(0)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode):
if is_varnode:
network = Network()
else:
network = None
reference = make_tensor(0, network)
# literal
x = [1, 2, 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x)
# numpy array
x = np.asarray([1, 2, 3], dtype="int32")
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
# tensor
x = tensor([1, 2, 3], dtype="int32")
x = make_tensor([1, 2, 3], network)
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.numpy())
# mixed
x = [1, tensor(2), 3]
x = [1, make_tensor(2, network), 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), [1, 2, 3])
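The switch from `type(xx) is tensor` to `isinstance(xx, type(reference))` is what lets this test run in both modes: under is_varnode the reference is a VarNode rather than a Tensor, and astensor1d follows the type of its reference argument. A minimal sketch of the invariant being asserted:

    ref = make_tensor(0, network)      # Tensor when network is None, else VarNode
    out = astensor1d([1, 2, 3], ref)
    assert isinstance(out, type(ref))  # holds on both paths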
@@ -377,35 +494,60 @@ def test_device():
np.testing.assert_almost_equal(y5.numpy(), y6.numpy())

def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_identity(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
y = F.copy(x)
np.testing.assert_equal(y.numpy(), x)

def copy_test(dst, src):
def copy_test(dst, src, network):
data = np.random.random((2, 3)).astype(np.float32)
x = tensor(data, device=src)
x = make_tensor(data, device=src, network=network)
y = F.copy(x, dst)
assert np.allclose(data, y.numpy())
z = x.to(dst)
assert np.allclose(data, z.numpy())
if network is None:
z = x.to(dst)
assert np.allclose(data, z.numpy())
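The `x.to(dst)` branch is now guarded; the guard is consistent with only the imperative Tensor carrying the `.to()` convenience method (an assumption, since VarNode is not shown here), so on the Network path device transfer is exercised through F.copy alone:

    x = make_tensor(data, network=None, device="cpu0")
    y = F.copy(x, "cpu1")   # available on both paths
    z = x.to("cpu1")        # Tensor-only, hence `if network is None`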
@pytest.mark.require_ngpu(1)
def test_copy_h2d():
copy_test("cpu0", "gpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_h2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None
copy_test("cpu0", "gpu0", network=network)

@pytest.mark.require_ngpu(1)
def test_copy_d2h():
copy_test("gpu0", "cpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2h(is_varnode):
if is_varnode:
network = Network()
else:
network = None
copy_test("gpu0", "cpu0", network=network)

@pytest.mark.require_ngpu(2)
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None
copy_test("gpu0", "gpu1", network=network)
copy_test("gpu0:0", "gpu0:1", network=network)

@pytest.mark.parametrize(
@@ -420,7 +562,13 @@ def test_copy_d2d():
((), 10, None),
],
)
def test_repeat(shape, repeats, axis):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_repeat(shape, repeats, axis, is_varnode):
if is_varnode:
network = Network()
else:
network = None
def repeat_func(inp):
return F.repeat(inp=inp, repeats=repeats, axis=axis)

@@ -432,7 +580,10 @@ def test_repeat(shape, repeats, axis):
cases = [{"input": np.array(1.23)}]
opr_test(
cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
cases,
repeat_func,
ref_fn=lambda inp: np.repeat(inp, repeats, axis),
network=network,
)

@@ -445,14 +596,16 @@ def test_repeat(shape, repeats, axis):
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
],
)
def test_tile(shape, reps):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_tile(shape, reps, is_varnode):
if is_varnode:
network = Network()
else:
network = None
def tile_func(inp):
return F.tile(inp=inp, reps=reps)
cases = [
{"input": np.random.randn(*shape).astype("float32")},
]
cases = [{"input": np.random.randn(*shape).astype("float32")}]
opr_test(
cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
)
opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)

@@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig(
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

inp_scale = np.float32(np.random.rand() + 1)
def gen_inp_scale():
return np.float32(np.random.rand() + 1)
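Turning the module-level constant into gen_inp_scale() means the random scale is drawn per call instead of once at import time, so each test (and each loop iteration in test_concat below) gets an independent scale:

    inp_scale = gen_inp_scale()   # fresh float32 in [1.0, 2.0) on every call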
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")

@@ -116,6 +119,7 @@ def test_dequant_stub():
q_net.eval()
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.scale = inp_scale

@@ -192,6 +196,7 @@ def test_linear():
init_qat_net(qat_net)
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))

@@ -235,6 +240,7 @@ def test_conv(module):
init_qat_net(qat_net)
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))

@@ -269,3 +275,41 @@ def test_conv(module):
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)

def test_concat():
normal_net = Float.Concat()
normal_net.eval()
qat_net = QAT.Concat()
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
inps = []
inps_int8 = []
for i in range(3):
inp_scale = gen_inp_scale()
inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
inps[i] = fake_quant_act(inps[i], inp_scale)
inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
inps_int8.append(quant(inps[i], inp_scale))
qat_from_float = QAT.Concat.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)
q_net = Q.Concat.from_qat_module(qat_net)
q_net.eval()
normal = normal_net(inps)
qat_without_fakequant = qat_from_float(inps)
fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
qat = qat_net(inps)
q = q_net(inps_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal.numpy())
np.testing.assert_allclose(q, fake_quant_normal.numpy())
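For orientation, quant() and fake_quant_act() are defined elsewhere in this test file; under the usual symmetric qint8 scheme they would behave roughly like this sketch (names and clipping range are assumptions, not taken from the diff):

    def quant_sketch(x, scale):
        return np.clip(np.round(x / scale), -128, 127)   # into the int8 domain

    def fake_quant_sketch(x, scale):
        return quant_sketch(x, scale) * scale            # back to float

which is why `q_net(inps_int8).numpy() * act_scale` is directly comparable with the fake-quantized float output.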
@@ -124,6 +124,35 @@ def test_with_submodule(symbolic):
@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule_in_container(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.l0 = [M.Linear(3, 3) for _ in range(2)]
self.l1 = tuple(self.l0)
self.l2 = dict(zip(["l2-0", "l2-1"], self.l0))
def forward(self, x):
for i in range(2):
x = self.l0[i](x)
x = self.l1[i](x)
x = self.l2["l2-%d" % i](x)
return x
m = Simple("simple")
ops = _dump_and_load(m, symbolic)
assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD"
assert ops[-1].name == "simple.l2.l2-1.ADD"
assert ops[-2].name == "simple.l2.l2-1.MatrixMul"
assert ops[-3].name == "simple.l1.1.ADD"
assert ops[-4].name == "simple.l1.1.MatrixMul"
assert ops[-5].name == "simple.l0.1.ADD"
assert ops[-6].name == "simple.l0.1.MatrixMul"
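The asserts walk the dumped ops from the back: each Linear lowers to a MatrixMul followed by an ADD, so the second loop iteration (i == 1) accounts for exactly the last six ops, with container attribute names (list index, tuple index, dict key) flowing into the op names:

    # i == 1, in execution order:
    # l0[1]      -> simple.l0.1.MatrixMul,    simple.l0.1.ADD
    # l1[1]      -> simple.l1.1.MatrixMul,    simple.l1.1.ADD
    # l2["l2-1"] -> simple.l2.l2-1.MatrixMul, simple.l2.l2-1.ADD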
@pytest.mark.parametrize("symbolic", [False, True]) | |||||
def test_named_submodule(symbolic): | def test_named_submodule(symbolic): | ||||
class Simple(M.Module): | class Simple(M.Module): | ||||
def __init__(self, name): | def __init__(self, name): | ||||
@@ -264,4 +293,4 @@ def test_quantized_module_user_naming_param(symbolic): | |||||
(matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"] | (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"] | ||||
for var in matrix_mul_op.inputs: | for var in matrix_mul_op.inputs: | ||||
assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight") | assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight") | ||||
# BUG bias' name does not meet expectations because of astype operator after quantization | |||||
# WONTFIX: bias' name does not meet expectations because of astype operator after quantization |
@@ -34,13 +34,11 @@ def test_replace_var(): | |||||
vara = graph.var_filter.name("a").as_unique() | vara = graph.var_filter.name("a").as_unique() | ||||
varb = graph.var_filter.name("b").as_unique() | varb = graph.var_filter.name("b").as_unique() | ||||
out = F.mul(vara.var, varb.var) | |||||
out = F.mul(vara, varb) | |||||
out = F.relu(out) | out = F.relu(out) | ||||
var_list = graph.add_dep_oprs(out) | |||||
opnode = list(graph.opr_filter.has_input(vara)) | opnode = list(graph.opr_filter.has_input(vara)) | ||||
repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||||
repl_dict = {opnode[0].outputs[0]: out} | |||||
graph.replace_vars(repl_dict) | graph.replace_vars(repl_dict) | ||||
modified_model = io.BytesIO() | modified_model = io.BytesIO() | ||||
@@ -72,14 +70,12 @@ def test_replace_opr(): | |||||
vara = graph.var_filter.name("a").as_unique() | vara = graph.var_filter.name("a").as_unique() | ||||
varb = graph.var_filter.name("b").as_unique() | varb = graph.var_filter.name("b").as_unique() | ||||
out1 = F.sub(vara.var, varb.var) | |||||
out1 = F.sub(vara, varb) | |||||
out1 = F.relu(out1) | out1 = F.relu(out1) | ||||
var_list = graph.add_dep_oprs(out1) | |||||
repl_opr = as_oprnode(var_list) | |||||
out1 = graph.add_dep_oprs(out1) | |||||
orig_opr = graph.opr_filter.has_input(vara).as_unique() | orig_opr = graph.opr_filter.has_input(vara).as_unique() | ||||
repl_dict = {orig_opr: repl_opr} | |||||
repl_dict = {orig_opr: out1[0].owner} | |||||
graph.replace_oprs(repl_dict) | graph.replace_oprs(repl_dict) | ||||
modified_model1 = io.BytesIO() | modified_model1 = io.BytesIO() | ||||
graph.dump(modified_model1) | graph.dump(modified_model1) | ||||
@@ -171,8 +167,7 @@ def test_add_input(): | |||||
inp_c = graph.make_input_node((2,), np.int32, name="c") | inp_c = graph.make_input_node((2,), np.int32, name="c") | ||||
varo = graph.var_filter.name("o").as_unique() | varo = graph.var_filter.name("o").as_unique() | ||||
out = F.add(varo.var, inp_c.var) | |||||
out = graph.add_dep_oprs(out)[0] | |||||
out = F.add(varo, inp_c) | |||||
out.name = "o1" | out.name = "o1" | ||||
graph.remove_output(varo) | graph.remove_output(varo) | ||||
graph.add_output(out) | graph.add_output(out) | ||||
@@ -206,12 +201,11 @@ def test_add_output(): | |||||
var_a = net.var_filter.name("a").as_unique() | var_a = net.var_filter.name("a").as_unique() | ||||
var_b = net.var_filter.name("b").as_unique() | var_b = net.var_filter.name("b").as_unique() | ||||
y = F.add(var_a.var, var_b.var) | |||||
y = F.add(var_a, var_b) | |||||
y = F.sigmoid(y) | y = F.sigmoid(y) | ||||
new_vars = net.add_dep_oprs(y)[0] | |||||
new_vars.name = "o1" | |||||
net.add_output(new_vars) | |||||
y.name = "o1" | |||||
net.add_output(y) | |||||
modified_model = io.BytesIO() | modified_model = io.BytesIO() | ||||
net.dump(modified_model) | net.dump(modified_model) | ||||
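The common thread in these hunks is that functional ops applied to Network nodes now return values the var-level editing API (replace_vars, add_output, remove_output) accepts directly, so the add_dep_oprs round-trip and helpers like as_oprnode largely disappear; only opr replacement still resolves owners via add_dep_oprs. A minimal sketch of the new flow, assuming a serialized model with vars named a and b (the path is illustrative):

    net = Network.load(model_path)
    a = net.var_filter.name("a").as_unique()
    b = net.var_filter.name("b").as_unique()
    y = F.sigmoid(F.add(a, b))       # ops build directly on the filtered nodes
    y.name = "o1"
    net.add_output(y)                # no add_dep_oprs round-trip needed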
@@ -466,6 +466,19 @@ def test_topk():
check_pygraph_dump(fwd, [x], [top, indices])

def test_nvof():
if not is_cuda_available():
return
src_shape = (4, 5, 224, 224, 4)
src = np.random.randint(0, 255, src_shape).astype("uint8")
src = Tensor(src)
@trace(symbolic=True, capture_as_const=True)
def fwd(src):
return F.nn.nvof(src, precision=1)
result = fwd(src)
check_pygraph_dump(fwd, [src], [result])

def test_random():

@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
__version__ = "1.3.0.dev"
__version__ = "1.3.1"

@@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Broadcast::make();
}

cg::OperatorNodeBase* apply_on_var_node(
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
OperatorNodeConfig config{op.make_name()};
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
return opr::Broadcast::make(inputs[0], inputs[1], config);
}

bool valid_broadcast(const TensorShape& src_shape,

@@ -1,6 +1,7 @@
# mgb tablegen executable
set(TABLE_TARGET mgb-mlir-autogen)
add_executable(${TABLE_TARGET} autogen.cpp)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(${TABLE_TARGET} ${SRCS})
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})

@@ -1,8 +1,17 @@
#include <iostream>
#include <unordered_map>
#include <functional>
#include "./helper.h"
/**
* \file imperative/tablegen/autogen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./targets/cpp_class.h"
#include "./targets/pybind11.h"
#include "./targets/python_c_extension.h"

using llvm::raw_ostream;
using llvm::RecordKeeper;

@@ -27,731 +36,7 @@ llvm::cl::opt<ActionType> action(
clEnumValN(CPython, "gen-python-c-extension",
"Generate python c extensions")));
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;

llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
// Note: we have already registered the corresponding attr wrappers
// for following basic ctypes so we needn't handle them here
/* auto&& attr_type_name = attr.getAttrDefName();
if (attr_type_name == "UI32Attr") {
return "uint32_t";
}
if (attr_type_name == "UI64Attr") {
return "uint64_t";
}
if (attr_type_name == "I32Attr") {
return "int32_t";
}
if (attr_type_name == "F32Attr") {
return "float";
}
if (attr_type_name == "F64Attr") {
return "double";
}
if (attr_type_name == "StrAttr") {
return "std::string";
}
if (attr_type_name == "BoolAttr") {
return "bool";
}*/
auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
return e->getEnumName();
}
return attr.getUnderlyingType();
}

static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
os << formatv(
"class {0} : public OpDefImplBase<{0}> {{\n"
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
"public:\n",
op.getCppClassName()
);
// handle enum alias
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
os << formatv(
" using {0} = {1};\n",
attr->getEnumName(), attr->getUnderlyingType()
);
}
}
for (auto &&i : op.getMgbAttributes()) {
auto defaultValue = i.attr.getDefaultValue().str();
if (!defaultValue.empty()) {
defaultValue = formatv(" = {0}", defaultValue);
}
os << formatv(
" {0} {1}{2};\n",
attr_to_ctype(i.attr), i.name, defaultValue
);
}
auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
os << formatv(
" {0}({1}){2}{3}\n",
op.getCppClassName(), paramList, memInitList, body
);
};
gen_ctor("", "", " = default;");
if (!op.getMgbAttributes().empty()) {
std::vector<std::string> paramList, initList;
for (auto &&i : op.getMgbAttributes()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
paramList.push_back("std::string scope_ = {}");
gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "),
" { set_scope(scope_); }");
}
auto packedParams = op.getPackedParams();
if (!packedParams.empty()) {
std::vector<std::string> paramList, initList;
for (auto &&p : packedParams) {
auto&& paramFields = p.getFields();
auto&& paramType = p.getFullName();
auto&& paramName = formatv("packed_param_{0}", paramList.size());
paramList.push_back(
paramFields.empty() ? paramType.str()
: formatv("{0} {1}", paramType, paramName)
);
for (auto&& i : paramFields) {
initList.push_back(formatv(
"{0}({1}.{0})", i.name, paramName
));
}
}
for (auto&& i : op.getExtraArguments()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
gen_ctor(llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "),
" {}");
}
if (!packedParams.empty()) {
for (auto&& p : packedParams) {
auto accessor = p.getAccessor();
if (!accessor.empty()) {
os << formatv(
" {0} {1}() const {{\n",
p.getFullName(), accessor
);
std::vector<llvm::StringRef> fields;
for (auto&& i : p.getFields()) {
fields.push_back(i.name);
}
os << formatv(
" return {{{0}};\n",
llvm::join(fields, ", ")
);
os << " }\n";
}
}
}
if (auto decl = op.getExtraOpdefDecl()) {
os << decl.getValue();
}
os << formatv(
"};\n\n"
);
}

static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}

static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
auto&& className = op.getCppClassName();
os << formatv(
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
);
auto formatMethImpl = [&](auto&& meth) {
return formatv(
"{0}_{1}_impl", className, meth
);
};
std::vector<std::string> methods;
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
os << "namespace {\n";
// generate hash()
mlir::tblgen::FmtContext ctx;
os << formatv(
"size_t {0}(const OpDef& def_) {{\n",
formatMethImpl("hash")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
os << "}\n";
// generate is_same_st()
os << formatv(
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
formatMethImpl("is_same_st")
);
os << formatv(
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
" &&b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
);
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";
// generate props()
os << formatv(
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
formatMethImpl("props")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";
// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n";
os << "} // anonymous namespace\n";
methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
methods.push_back("make_name");
}
if (!methods.empty()) {
os << formatv(
"OP_TRAIT_REG({0}, {0})", op.getCppClassName()
);
for (auto&& i : methods) {
os << formatv(
"\n .{0}({1})", i, formatMethImpl(i)
);
}
os << ";\n\n";
}
}

struct EnumContext {
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
};

static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
os << formatv(
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
className
);
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
if (iter == enumAlias.end()) {
os << formatv(
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
className, attr->getEnumName()
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(), i
);
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i
));
}
if (attr->getEnumCombinedFlag()) {
//! define operator |
os << formatv(
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
"\n })",
className, attr->getEnumName());
//! define operator &
os << formatv(
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
"\n })",
className, attr->getEnumName());
}
os << formatv(
"\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);"
"\n {0}"
"\n throw py::cast_error(\"invalid enum value \" + in);"
"\n }));\n",
llvm::join(body, "\n ")
);
os << formatv(
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
className, attr->getEnumName()
);
enumAlias.emplace(enumID,
std::make_pair(className, attr->getEnumName()));
} else {
os << formatv(
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
className, attr->getEnumName(),
iter->second.first, iter->second.second
);
}
}
}
// generate op class binding
os << formatv("{0}Inst", className);
bool hasDefaultCtor = op.getMgbAttributes().empty();
if (!hasDefaultCtor) {
os << "\n .def(py::init<";
std::vector<llvm::StringRef> targs;
for (auto &&i : op.getMgbAttributes()) {
targs.push_back(i.attr.getReturnType());
}
os << llvm::join(targs, ", ");
os << ", std::string>()";
for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue();
if (!defaultValue.empty()) {
os << formatv(" = {0}", defaultValue);
} else {
hasDefaultCtor = true;
}
}
os << ", py::arg(\"scope\") = {})";
}
if (hasDefaultCtor) {
os << "\n .def(py::init<>())";
}
for (auto &&i : op.getMgbAttributes()) {
os << formatv(
"\n .def_readwrite(\"{0}\", &{1}::{0})",
i.name, className
);
}
os << ";\n\n";
}
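The `__or__`/`__and__` definitions emitted above give bit_flags enums set-like semantics on the Python side. A hedged sketch of what the generated binding permits (SomeOp and Mode are placeholder names, not classes this generator is known to emit):

    from megengine.core._imperative_rt import ops  # module targeted by the codegen
    s = ops.SomeOp.Mode.FLAG_A | ops.SomeOp.Mode.FLAG_B   # __or__ combines members
    both = s & ops.SomeOp.Mode.FLAG_A                     # __and__ tests membership
    m = ops.SomeOp.Mode("flag_a")                         # str init via normalize_enum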
static std::string gen_op_def_python_c_extension_enum( | |||||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||||
llvm::StringRef className) { | |||||
std::string body; | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = ctx.enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
auto enumName = attr->getEnumName(); | |||||
body += "{\n"; | |||||
body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||||
enumName); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> const char* EnumWrapper<{0}::{1}>::name = " | |||||
"\"{0}.{1}\";\n", | |||||
className, enumName); | |||||
std::vector<std::string> pairStr; | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<std::string, {0}::{1}> | |||||
EnumWrapper<{0}::{1}>::str2type = {{ | |||||
{2} | |||||
}; | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
pairStr.clear(); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<{0}::{1}, std::string> | |||||
EnumWrapper<{0}::{1}>::type2str = {{ | |||||
{2} | |||||
}; | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
body += formatv(R"( | |||||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||||
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
e_type.tp_doc = "{0}.{1}"; | |||||
e_type.tp_base = &PyBaseObject_Type; | |||||
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||||
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||||
)", | |||||
className, enumName); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
body += formatv(R"({{ | |||||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||||
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||||
})", | |||||
className, enumName, i); | |||||
} | |||||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
} | |||||
body += formatv(R"( | |||||
PyType_Modified(&e_type); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||||
)", | |||||
enumName); | |||||
body += "}\n"; | |||||
return body; | |||||
} | |||||
static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||||
llvm::StringRef className) { | |||||
std::string body; | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = ctx.enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
auto enumName = attr->getEnumName(); | |||||
body += "{\n"; | |||||
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||||
className, enumName); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"template<> PyTypeObject " | |||||
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> PyNumberMethods " | |||||
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||||
"= \"{0}.{1}\";\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||||
"bool is_bit_combined = true;};\n", | |||||
className, enumName); | |||||
std::vector<std::string> pairStr; | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<std::string, {0}::{1}> | |||||
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||||
{2} | |||||
}; | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
pairStr.clear(); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<{0}::{1}, std::string> | |||||
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||||
{2} | |||||
}; | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
body += formatv(R"( | |||||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||||
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
e_type.tp_doc = "{0}.{1}"; | |||||
e_type.tp_base = &PyBaseObject_Type; | |||||
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||||
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||||
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||||
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||||
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||||
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||||
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||||
e_type.tp_as_number = &number_method; | |||||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||||
)", | |||||
className, enumName); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
body += formatv(R"({{ | |||||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||||
})", | |||||
className, enumName, i); | |||||
} | |||||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
} | |||||
body += formatv(R"( | |||||
PyType_Modified(&e_type); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||||
)", | |||||
enumName); | |||||
body += "}\n"; | |||||
return body; | |||||
} | |||||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
std::string body;
// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->getEnumCombinedFlag()) {
body += gen_op_def_python_c_extension_bit_combined_enum(
os, ctx, attr, className);
} else {
body += gen_op_def_python_c_extension_enum(os, ctx, attr,
className);
}
}
}
// generate getsetters
std::vector<std::string> getsetters;
for (auto &&i : op.getMgbAttributes()) {
getsetters.push_back(formatv(
"{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
className, i.name));
}
// generate tp_init
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";
std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "NULL};\n";
initBody += " PyObject ";
std::vector<std::string> attr_init;
llvm::for_each(attr_name_list, [&](auto&& attr) {
attr_init.push_back(formatv("*{0} = NULL", attr));
});
initBody += llvm::join(attr_init, ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += "))\n";
initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(R"(
if ({1}) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
pyobj_convert_generic<decltype({0}::{1})>::from({1});
} CATCH_ALL(-1)
}
)", className, attr.name);
});
initBody += formatv(R"(
if (scope) {{
try {{
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", className);
}
initBody += "\n return 0;";
os << formatv(R"(
PyOpDefBegin({0}) // {{
static PyGetSetDef py_getsetters[];
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd({0})
PyGetSetDef PyOp({0})::py_getsetters[] = {{
{1}
{{NULL}  /* Sentinel */
};
int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
{2}
}
void _init_py_{0}(py::module m) {{
using py_op = PyOp({0});
auto& py_type = PyOpType({0});
py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
py_type.tp_basicsize = sizeof(PyOp({0}));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "{0}";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
{3}
PyType_Modified(&py_type);
m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
}
)",
op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
}
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
std::function<void(raw_ostream&, MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
ASSERT(op_base_class, "could not find base class Op");
for (auto&& i: keeper.getDefs()) {
auto&& r = i.second;
if (r->isSubClassOf(op_base_class)) {
auto op = mlir::tblgen::Operator(r.get());
if (op.getDialectName().str() == "mgb") {
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
callback(os, llvm::cast<MgbOp>(op));
}
}
}
}

static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_header_single);
for_each_operator(os, keeper, gen_to_string_trait_for_enum);
return false;
}

static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_body_single);
return false;
}

static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
EnumContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
return false;
}

static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
EnumContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
os << "#define INIT_ALL_OP(m)";
for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
});
os << "\n";
return false;
}

using namespace mlir::tblgen;
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
@@ -0,0 +1,40 @@
/**
 * \file imperative/tablegen/emitter.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <unordered_map>
#include <stdexcept>

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir::tblgen {

struct Environment {
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
};

struct EmitterBase {
EmitterBase(raw_ostream& os_): os(os_) {}
EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {}

protected:
void newline() { os << "\n"; }
Environment& env() {
if (env_p) {
return *env_p;
}
throw std::runtime_error("access global environment via non-environment emitter");
}

raw_ostream& os;
Environment* env_p = nullptr;
};

} // namespace mlir::tblgen
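A usage sketch (illustration only, not part of the patch): a concrete emitter derives from EmitterBase, writes through `os`, and can reach the shared enum-alias table only when it was constructed with an Environment, which keeps access to global state explicit:

// Assumes the emitter.h above is on the include path.
#include "emitter.h"

namespace mlir::tblgen {
struct HelloEmitter : EmitterBase {
    using EmitterBase::EmitterBase;  // inherit both constructors
    void emit() {
        os << "// hello from tablegen";
        newline();
        if (env_p) {  // env() would throw for an emitter built without an Environment
            os << "// known enum aliases: " << env().enumAlias.size();
            newline();
        }
    }
};
} // namespace mlir::tblgen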
@@ -1,3 +1,16 @@
/**
 * \file imperative/tablegen/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <iostream>
#include <string>
#include <vector>
@@ -278,5 +291,28 @@ public:
}
};

using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;

static inline void foreach_operator(llvm::RecordKeeper &keeper,
std::function<void(MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
ASSERT(op_base_class, "could not find base class Op");
for (auto&& i: keeper.getDefs()) {
auto&& r = i.second;
if (r->isSubClassOf(op_base_class)) {
auto op = mlir::tblgen::Operator(r.get());
if (op.getDialectName().str() == "mgb") {
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
callback(llvm::cast<MgbOp>(op));
}
}
}
}
} // namespace tblgen
} // namespace mlir
@@ -0,0 +1,309 @@
/**
 * \file imperative/tablegen/targets/cpp_class.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./cpp_class.h" | |||||
#include "../emitter.h" | |||||
namespace mlir::tblgen { | |||||
namespace { | |||||
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||||
// Note: we have already registered the corresponding attr wrappers | |||||
// for following basic ctypes so we needn't handle them here | |||||
/* auto&& attr_type_name = attr.getAttrDefName(); | |||||
if (attr_type_name == "UI32Attr") { | |||||
return "uint32_t"; | |||||
} | |||||
if (attr_type_name == "UI64Attr") { | |||||
return "uint64_t"; | |||||
} | |||||
if (attr_type_name == "I32Attr") { | |||||
return "int32_t"; | |||||
} | |||||
if (attr_type_name == "F32Attr") { | |||||
return "float"; | |||||
} | |||||
if (attr_type_name == "F64Attr") { | |||||
return "double"; | |||||
} | |||||
if (attr_type_name == "StrAttr") { | |||||
return "std::string"; | |||||
} | |||||
if (attr_type_name == "BoolAttr") { | |||||
return "bool"; | |||||
}*/ | |||||
auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||||
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||||
return e->getEnumName(); | |||||
} | |||||
return attr.getUnderlyingType(); | |||||
} | |||||
class OpDefEmitter final: public EmitterBase { | |||||
public: | |||||
OpDefEmitter(MgbOp& op_, raw_ostream& os_): | |||||
EmitterBase(os_), op(op_) {} | |||||
void emit_header(); | |||||
void emit_tpl_spl(); | |||||
void emit_body(); | |||||
private: | |||||
MgbOp& op; | |||||
}; | |||||
void OpDefEmitter::emit_header() { | |||||
os << formatv( | |||||
"class {0} : public OpDefImplBase<{0}> {{\n" | |||||
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||||
"public:\n", | |||||
op.getCppClassName() | |||||
); | |||||
// handle enum alias | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
os << formatv( | |||||
" using {0} = {1};\n", | |||||
attr->getEnumName(), attr->getUnderlyingType() | |||||
); | |||||
} | |||||
} | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
auto defaultValue = i.attr.getDefaultValue().str(); | |||||
if (!defaultValue.empty()) { | |||||
defaultValue = formatv(" = {0}", defaultValue); | |||||
} | |||||
os << formatv( | |||||
" {0} {1}{2};\n", | |||||
attr_to_ctype(i.attr), i.name, defaultValue | |||||
); | |||||
} | |||||
auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||||
os << formatv( | |||||
" {0}({1}){2}{3}\n", | |||||
op.getCppClassName(), paramList, memInitList, body | |||||
); | |||||
}; | |||||
gen_ctor("", "", " = default;"); | |||||
if (!op.getMgbAttributes().empty()) { | |||||
std::vector<std::string> paramList, initList; | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
paramList.push_back(formatv( | |||||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||||
)); | |||||
initList.push_back(formatv( | |||||
"{0}({0}_)", i.name | |||||
)); | |||||
} | |||||
paramList.push_back("std::string scope_ = {}"); | |||||
gen_ctor(llvm::join(paramList, ", "), | |||||
": " + llvm::join(initList, ", "), | |||||
" { set_scope(scope_); }"); | |||||
} | |||||
auto packedParams = op.getPackedParams(); | |||||
if (!packedParams.empty()) { | |||||
std::vector<std::string> paramList, initList; | |||||
for (auto &&p : packedParams) { | |||||
auto&& paramFields = p.getFields(); | |||||
auto&& paramType = p.getFullName(); | |||||
auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||||
paramList.push_back( | |||||
paramFields.empty() ? paramType.str() | |||||
: formatv("{0} {1}", paramType, paramName) | |||||
); | |||||
for (auto&& i : paramFields) { | |||||
initList.push_back(formatv( | |||||
"{0}({1}.{0})", i.name, paramName | |||||
)); | |||||
} | |||||
} | |||||
for (auto&& i : op.getExtraArguments()) { | |||||
paramList.push_back(formatv( | |||||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||||
)); | |||||
initList.push_back(formatv( | |||||
"{0}({0}_)", i.name | |||||
)); | |||||
} | |||||
gen_ctor(llvm::join(paramList, ", "), | |||||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||||
" {}"); | |||||
} | |||||
if (!packedParams.empty()) { | |||||
for (auto&& p : packedParams) { | |||||
auto accessor = p.getAccessor(); | |||||
if (!accessor.empty()) { | |||||
os << formatv( | |||||
" {0} {1}() const {{\n", | |||||
p.getFullName(), accessor | |||||
); | |||||
std::vector<llvm::StringRef> fields; | |||||
for (auto&& i : p.getFields()) { | |||||
fields.push_back(i.name); | |||||
} | |||||
os << formatv( | |||||
" return {{{0}};\n", | |||||
llvm::join(fields, ", ") | |||||
); | |||||
os << " }\n"; | |||||
} | |||||
} | |||||
} | |||||
if (auto decl = op.getExtraOpdefDecl()) { | |||||
os << decl.getValue(); | |||||
} | |||||
os << formatv( | |||||
"};\n\n" | |||||
); | |||||
} | |||||
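// Illustrative output (assumed example, not emitted by this file as-is): for
// a hypothetical op `Axis` with one int32_t attribute `axis` defaulting to 0,
// emit_header() prints roughly:
//
//   class Axis : public OpDefImplBase<Axis> {
//       MGB_DYN_TYPE_OBJ_FINAL_DECL;
//   public:
//       int32_t axis = 0;
//       Axis() = default;
//       Axis(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
//   };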
void OpDefEmitter::emit_tpl_spl() {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {{
std::string operator()({0} e) const {{
switch (e) {{
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}
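// Illustrative output: for a hypothetical enum `Foo::Mode { A, B }` with
// to-string support, the formatv call above would emit roughly:
//
//   template <>
//   struct ToStringTrait<Foo::Mode> {
//       std::string operator()(Foo::Mode e) const {
//           switch (e) {
//               case Foo::Mode::A: return "A";
//               case Foo::Mode::B: return "B";
//               default: return "Foo::Mode::Unknown";
//           }
//       }
//   };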
void OpDefEmitter::emit_body() {
auto&& className = op.getCppClassName();
os << formatv(
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
);
auto formatMethImpl = [&](auto&& meth) {
return formatv(
"{0}_{1}_impl", className, meth
);
};
std::vector<std::string> methods;
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
os << "namespace {\n";

// generate hash()
mlir::tblgen::FmtContext ctx;
os << formatv(
"size_t {0}(const OpDef& def_) {{\n",
formatMethImpl("hash")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
os << "}\n";

// generate is_same_st()
os << formatv(
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
formatMethImpl("is_same_st")
);
os << formatv(
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
" &&b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
);
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";

// generate props()
os << formatv(
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
formatMethImpl("props")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";

// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n";

os << "} // anonymous namespace\n";

methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
methods.push_back("make_name");
}
if (!methods.empty()) {
os << formatv(
"OP_TRAIT_REG({0}, {0})", op.getCppClassName()
);
for (auto&& i : methods) {
os << formatv(
"\n .{0}({1})", i, formatMethImpl(i)
);
}
os << ";\n\n";
}
}

} // namespace

bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) {
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter emitter(op, os);
emitter.emit_header();
emitter.emit_tpl_spl();
});
return false;
}

bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) {
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter emitter(op, os);
emitter.emit_body();
});
return false;
}

} // namespace mlir::tblgen
@@ -0,0 +1,21 @@
/**
 * \file imperative/tablegen/targets/cpp_class.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "../helper.h"

namespace mlir::tblgen {

bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper);

bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper);

} // namespace mlir::tblgen
@@ -0,0 +1,142 @@
/**
 * \file imperative/tablegen/targets/pybind11.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./pybind11.h" | |||||
#include "../emitter.h" | |||||
namespace mlir::tblgen { | |||||
namespace { | |||||
class OpDefEmitter final: public EmitterBase { | |||||
public: | |||||
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||||
EmitterBase(os_, env_), op(op_) {} | |||||
void emit(); | |||||
private: | |||||
MgbOp& op; | |||||
}; | |||||
void OpDefEmitter::emit() { | |||||
auto className = op.getCppClassName(); | |||||
os << formatv( | |||||
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||||
className | |||||
); | |||||
for (auto&& i : op.getMgbAttributes()) { | |||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = | |||||
llvm::cast<MgbEnumAttr>(aliasBase) | |||||
.getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = env().enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||||
className, attr->getEnumName() | |||||
); | |||||
std::vector<std::string> body; | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
os << formatv( | |||||
"\n .value(\"{2}\", {0}::{1}::{2})", | |||||
className, attr->getEnumName(), i | |||||
); | |||||
body.push_back(formatv( | |||||
"if (str == \"{2}\") return {0}::{1}::{2};", | |||||
className, attr->getEnumName(), i | |||||
)); | |||||
} | |||||
if (attr->getEnumCombinedFlag()) { | |||||
//! define operator | | |||||
os << formatv( | |||||
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||||
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||||
"\n })", | |||||
className, attr->getEnumName()); | |||||
//! define operator & | |||||
os << formatv( | |||||
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||||
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||||
"\n })", | |||||
className, attr->getEnumName()); | |||||
} | |||||
os << formatv( | |||||
"\n .def(py::init([](const std::string& in) {" | |||||
"\n auto&& str = normalize_enum(in);" | |||||
"\n {0}" | |||||
"\n throw py::cast_error(\"invalid enum value \" + in);" | |||||
"\n }));\n", | |||||
llvm::join(body, "\n ") | |||||
); | |||||
os << formatv( | |||||
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||||
className, attr->getEnumName() | |||||
); | |||||
enumAlias.emplace(enumID, | |||||
std::make_pair(className, attr->getEnumName())); | |||||
} else { | |||||
os << formatv( | |||||
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||||
className, attr->getEnumName(), | |||||
iter->second.first, iter->second.second | |||||
); | |||||
} | |||||
} | |||||
} | |||||
// generate op class binding | |||||
os << formatv("{0}Inst", className); | |||||
bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||||
if (!hasDefaultCtor) { | |||||
os << "\n .def(py::init<"; | |||||
std::vector<llvm::StringRef> targs; | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
targs.push_back(i.attr.getReturnType()); | |||||
} | |||||
os << llvm::join(targs, ", "); | |||||
os << ", std::string>()"; | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
os << formatv(", py::arg(\"{0}\")", i.name); | |||||
auto defaultValue = i.attr.getDefaultValue(); | |||||
if (!defaultValue.empty()) { | |||||
os << formatv(" = {0}", defaultValue); | |||||
} else { | |||||
hasDefaultCtor = true; | |||||
} | |||||
} | |||||
os << ", py::arg(\"scope\") = {})"; | |||||
} | |||||
if (hasDefaultCtor) { | |||||
os << "\n .def(py::init<>())"; | |||||
} | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
os << formatv( | |||||
"\n .def_readwrite(\"{0}\", &{1}::{0})", | |||||
i.name, className | |||||
); | |||||
} | |||||
os << ";\n\n"; | |||||
} | |||||
} // namespace | |||||
bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||||
Environment env; | |||||
using namespace std::placeholders; | |||||
foreach_operator(keeper, [&](MgbOp& op) { | |||||
OpDefEmitter(op, os, env).emit(); | |||||
}); | |||||
return false; | |||||
} | |||||
} // namespace mlir::tblgen |
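An output sketch (hypothetical op `Axis` with a single int32_t attribute `axis`; names and exact text are assumptions, the generator above may format details differently):

py::class_<Axis, std::shared_ptr<Axis>, OpDef> AxisInst(m, "Axis");

AxisInst
    .def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
    .def_readwrite("axis", &Axis::axis);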
@@ -0,0 +1,19 @@
/**
 * \file imperative/tablegen/targets/pybind11.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "../helper.h"

namespace mlir::tblgen {

bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper);

} // namespace mlir::tblgen
@@ -0,0 +1,314 @@
/**
 * \file imperative/tablegen/targets/python_c_extension.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "python_c_extension.h" | |||||
#include "../emitter.h" | |||||
namespace mlir::tblgen { | |||||
namespace { | |||||
struct Initproc { | |||||
std::string func; | |||||
Initproc(std::string&& s): func(std::move(s)) {} | |||||
std::string operator()(std::string argument) { | |||||
return formatv("{0}({1})", func, argument); | |||||
} | |||||
}; | |||||
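// For example, Initproc("_init_py_Convolution")("m") renders the call string
// "_init_py_Convolution(m)", which is later spliced into INIT_ALL_OP below.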
class OpDefEmitter: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), op(op_) {
ctx.withSelf(op.getCppClassName());
}

Initproc emit();

private:
void emit_class();
void emit_py_init();
void emit_py_getsetters();
Initproc emit_initproc();

MgbOp& op;
std::vector<Initproc> subclasses;
mlir::tblgen::FmtContext ctx;
};

class EnumAttrEmitter: public EmitterBase {
public:
EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), attr(attr_) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
ctx.addSubst("opClass", parent);
ctx.addSubst("enumClass", attr->getEnumName());
firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
}

Initproc emit();

protected:
void emit_tpl_spl();
Initproc emit_initproc();

MgbEnumAttr* attr;
bool firstOccur;
mlir::tblgen::FmtContext ctx;
};

Initproc EnumAttrEmitter::emit() {
emit_tpl_spl();
return emit_initproc();
}

void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return;
os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n",
&ctx);
os << tgfmt(
"template<> const char* $enumTpl<$opClass::$enumClass>::name = "
"\"$opClass.$enumClass\";\n",
&ctx);
if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n",
&ctx);
os << tgfmt(R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr bool is_bit_combined = true;
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)", &ctx, attr->getEnumMembers().size());
}
auto str2type = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
$enumTpl<$opClass::$enumClass>::str2type = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", "));
auto type2str = [&](auto&& i) -> std::string {
return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<$opClass::$enumClass, std::string>
$enumTpl<$opClass::$enumClass>::type2str = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", "));
}
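// Worked example: a bit-combined enum with 3 members gets
// max = (1llu << 3) - 1 = 0b111 = 7, i.e. the OR of all single-bit member
// values; any combined value above this bound cannot be valid.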
Initproc EnumAttrEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}_{1}",
ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));
os << tgfmt(R"(
void $0(PyTypeObject& py_type) {
auto& e_type = $enumTpl<$opClass::$enumClass>::type;
)", &ctx, initproc);
if (firstOccur) {
os << tgfmt(R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass";
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
)", &ctx);
if (attr->getEnumCombinedFlag()) {
// only bit-combined enums may create new instances (needed for the
// results of bitwise operations); other enums always use singletons
os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
)", &ctx);
}
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n";
for (auto&& i : attr->getEnumMembers()) {
os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
PyType_Modified(&e_type);
})", &ctx, i);
}
}
os << tgfmt(R"(
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)", &ctx);
os << "}\n";
return initproc;
}
Initproc OpDefEmitter::emit() {
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
}
}
emit_class();
emit_py_init();
emit_py_getsetters();
return emit_initproc();
}

void OpDefEmitter::emit_class() {
os << tgfmt(R"(
PyOpDefBegin($_self) // {
static PyGetSetDef py_getsetters[];
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd($_self)
)", &ctx);
}

void OpDefEmitter::emit_py_init() {
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";
std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "NULL};\n";
initBody += " PyObject ";
auto initializer = [&](auto&& attr) -> std::string {
return formatv("*{0} = NULL", attr);
};
initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += "))\n";
initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += tgfmt(R"(
if ($0) {
try {
reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
pyobj_convert_generic<decltype($_self::$0)>::from($0);
} CATCH_ALL(-1)
}
)", &ctx, attr.name);
});
initBody += tgfmt(R"(
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", &ctx);
}
initBody += "\n return 0;";
os << tgfmt(R"(
int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
$0
}
)", &ctx, initBody);
}

void OpDefEmitter::emit_py_getsetters() {
auto f = [&](auto&& attr) -> std::string {
return tgfmt(
"{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
&ctx, attr.name);
};
os << tgfmt(R"(
PyGetSetDef PyOp($_self)::py_getsetters[] = {
$0
{NULL}  /* Sentinel */
};
)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
}

Initproc OpDefEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
std::string subclass_init_call;
for (auto&& i : subclasses) {
subclass_init_call += formatv(" {0};\n", i("py_type"));
}
os << tgfmt(R"(
void $0(py::module m) {
using py_op = PyOp($_self);
auto& py_type = PyOpType($_self);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
py_type.tp_basicsize = sizeof(PyOp($_self));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "$_self";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
$1
PyType_Modified(&py_type);
m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
}
)", &ctx, initproc, subclass_init_call);
return initproc;
}
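// Illustrative output (assumed op `Axis`): emit_initproc() prints a function
// shaped like
//
//   void _init_py_Axis(py::module m) {
//       using py_op = PyOp(Axis);
//       auto& py_type = PyOpType(Axis);
//       ...
//       m.add_object("Axis", reinterpret_cast<PyObject*>(&py_type));
//   }
//
// and returns Initproc("_init_py_Axis"), so the caller can later splice
// "_init_py_Axis(m)" into the INIT_ALL_OP macro.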
} // namespace

bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
Environment env;
using namespace std::placeholders;
std::vector<Initproc> initprocs;
foreach_operator(keeper, [&](MgbOp& op) {
initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
});
os << "#define INIT_ALL_OP(m)";
for(auto&& init : initprocs) {
os << formatv(" \\\n {0};", init("m"));
}
os << "\n";
return false;
}

} // namespace mlir::tblgen
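For two hypothetical ops `Axis` and `Convolution`, the generated trailer would expand to roughly:

#define INIT_ALL_OP(m) \
    _init_py_Axis(m); \
    _init_py_Convolution(m);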
@@ -0,0 +1,19 @@
/**
 * \file imperative/tablegen/targets/python_c_extension.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "../helper.h"

namespace mlir::tblgen {

bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper);

} // namespace mlir::tblgen
@@ -709,7 +709,7 @@ void run_test_st(Args &env) {
strategy = S::PROFILE;
}
} else if (env.use_fast_run) {
strategy = S::PROFILE | S::OPTMIZED;
strategy = S::PROFILE | S::OPTIMIZED;
} else if (env.reproducible) {
strategy = S::HEURISTIC | S::REPRODUCIBLE;
}
@@ -365,14 +365,16 @@ namespace mgb {
if (!m_free_task_block.empty()) {
ret = std::move(m_free_task_block.back());
m_free_task_block.pop_back();
break;
} else if (m_block_quota > 0) {
ret = std::make_unique<TaskBlock>();
m_block_quota--;
break;
} else {
m_cv.wait(m_mutex);
continue;
}
} while (false);
} while (true);
ret->first_tid = m_new_block_first_tid;
m_new_block_first_tid += BLOCK_SIZE;
ret->prev = prev;
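The hunk above fixes a subtle exit path: inside `do { ... } while (false)`, the `continue` after `m_cv.wait(m_mutex)` jumps to the loop condition, which is false, so a woken thread could leave the loop without ever obtaining a task block. Making the loop endless with explicit `break`s on the two success paths makes the retry actually happen. A minimal self-contained sketch of the corrected pattern (assumed names, not the MegEngine implementation):

#include <condition_variable>
#include <memory>
#include <mutex>
#include <vector>

struct TaskBlock {};

struct BlockPool {
    std::vector<std::unique_ptr<TaskBlock>> m_free_task_block;
    size_t m_block_quota = 4;
    std::mutex m_mutex;
    std::condition_variable m_cv;

    std::unique_ptr<TaskBlock> acquire() {
        std::unique_lock<std::mutex> lock{m_mutex};
        std::unique_ptr<TaskBlock> ret;
        while (true) {
            if (!m_free_task_block.empty()) {
                ret = std::move(m_free_task_block.back());
                m_free_task_block.pop_back();
                break;  // got a recycled block
            } else if (m_block_quota > 0) {
                ret = std::make_unique<TaskBlock>();
                m_block_quota--;
                break;  // got a fresh block within quota
            } else {
                m_cv.wait(lock);  // re-check both conditions after wakeup
            }
        }
        return ret;
    }
};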
@@ -12,8 +12,8 @@
#pragma once

#define MGB_MAJOR 8
#define MGB_MINOR 9999
#define MGB_PATCH 0
#define MGB_MINOR 10
#define MGB_PATCH 1
//! whether it is development version
#ifndef MGB_IS_DEV
#define MGB_IS_DEV 0
@@ -1565,7 +1565,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
if (new_inp[i]->shape()[1] % 4 != 0) {
can_exec_cd4 = false;
}
//! cd4 elemwise with scaler is supported
//! cd4 elemwise with scalar is unsupported
} else if (!new_inp[i]->shape().is_scalar()) {
can_exec_cd4 = false;
}
@@ -1627,6 +1627,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr;
@@ -1265,6 +1265,55 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptInference, ConvertFormatNHWCD4TypeCvt) {
NaiveMegDNNHandleScope naive_megdnn_handle;
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x = gen({8, 8, 8, 8}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
opr::Convolution::Param param;
param.pad_h = param.pad_w = 0;
auto w1 = mkcvar("w1", {8, 8, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param),
tcvt1 = opr::TypeCvt::make(conv1, dtype::Float16());
auto w2 = mkcvar("w2", {8, 8, 3, 3}),
conv2 = opr::Convolution::make(x, w2, param),
tcvt2 = opr::TypeCvt::make(conv2, dtype::Float16());
auto y = opr::Elemwise::make({tcvt1, tcvt2}, opr::Elemwise::Param::Mode::ADD);
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
find_opr<opr::Convolution>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNHWCD4TypeCvt.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
*host_x = *gen({8, 8, 16, 16}, cn);
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}

TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
@@ -1707,8 +1756,8 @@ TEST(TestGoptInference, FastProfileCache) {
using S = opr::Convolution::ExecutionPolicy::Strategy;
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy);
gopt::modify_opr_algo_strategy_inplace({z + 2.3f},
S::PROFILE | S::OPTMIZED);
ASSERT_EQ(S::PROFILE | S::OPTMIZED, conv.execution_policy().strategy);
S::PROFILE | S::OPTIMIZED);
ASSERT_EQ(S::PROFILE | S::OPTIMIZED, conv.execution_policy().strategy);
}

TEST(TestGoptInference, AlgoWorkspaceLimit) {
@@ -6,7 +6,7 @@ decl_opr('Convolution',
'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')],
('execution_polity', 'ExecutionPolicyV0')],
desc='batched convolution on channeled 2D images')
decl_opr('Convolution',
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')],
('execution_polity', 'ExecutionPolicyV0')],
body=[
'a, b = all_inputs',
'all_inputs = [b, a]'
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
Doc('bias', 'bias'),
],
params=[('param', 'ConvBiasV1'),
('execution_policy', 'ExecutionPolicy')],
('execution_policy', 'ExecutionPolicyV0')],
desc=('activation(convolution(src, filter) + bias) with specified '
'dtype'),
has_out_dtype=True)
@@ -283,7 +283,7 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
static bool algo_attribute_match_strategy(AlgoAttribute attribute,
ExecutionStrategy selected_strategy) {
bool ret = true;
if (selected_strategy & ExecutionStrategy::OPTMIZED) {
if (selected_strategy & ExecutionStrategy::OPTIMIZED) {
ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute));
} else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) {
ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute);
@@ -357,7 +357,7 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
#if MGB_ENABLE_FASTRUN
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -444,7 +444,7 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
#if MGB_ENABLE_FASTRUN
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -1717,7 +1717,7 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) {
#if MGB_ENABLE_FASTRUN
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -1828,7 +1828,7 @@ TEST(TestOprDNN, DeformableConvForward) {
#if MGB_ENABLE_FASTRUN
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -1997,7 +1997,7 @@ TEST(TestOprDNN, BatchConvBiasForward) {
#if MGB_ENABLE_FASTRUN
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -290,11 +290,8 @@ ExternCOprRunner::ExternCOprRunner(std::string& name,
m_dump_name{name},
m_param{nullptr} {
mgb_assert(m_desc->size == sizeof(MGBOprDesc),
"invalid MGBOprDesc size: expect=%zu got=%u, may caused by "
"extern_c_opr.h mismatch, please confirm that the "
"extern_c_opr.h used when compiling the loader is consistent "
"with the runtime caller build used",
sizeof(MGBOprDesc), m_desc->size);
"invalid MGBOprDesc size: expect=%zu got=%u", sizeof(MGBOprDesc),
m_desc->size);
for (auto i : inputs) {
add_input({i});
}
@@ -0,0 +1,8 @@
add_custom_command(
OUTPUT link_sh
COMMAND ${CMAKE_COMMAND} -E create_symlink
${PROJECT_SOURCE_DIR}/tools/mlir/mgb-file-check/mgb-file-check.sh
${PROJECT_BINARY_DIR}/tools/mlir/mgb-file-check/mgb-file-check
)
add_custom_target(mgb-file-check DEPENDS link_sh)
@@ -0,0 +1,3 @@
#!/bin/bash -e

FileCheck --enable-var-scope --dump-input=fail "$@"
@@ -0,0 +1,23 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

set(LIBS
${dialect_libs}
${conversion_libs}
LLVMSupport
MLIROptLib
MLIRIR
MLIRPass
MLIRSupport
)

add_executable(mgb-opt mgb-opt.cpp)
target_include_directories(
mgb-opt
PRIVATE ${MLIR_LLVM_INCLUDE_DIR} ${PROJECT_SOURCE_DIR}/src/jit/include
${PROJECT_BINARY_DIR}/src/jit/include)
add_dependencies(mgb-opt mgb_dialect)
target_link_libraries(mgb-opt PRIVATE ${LIBS} megbrain megdnn ${MGE_CUDA_LIBS})
llvm_update_compile_flags(mgb-opt)
@@ -0,0 +1,85 @@
/**
 * \file tools/mlir/mgb-opt/mgb-opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/jit/mlir/ir/passes.h" | |||||
#include <llvm/Support/CommandLine.h> | |||||
#include <llvm/Support/InitLLVM.h> | |||||
#include <llvm/Support/PrettyStackTrace.h> | |||||
#include <llvm/Support/SourceMgr.h> | |||||
#include <llvm/Support/ToolOutputFile.h> | |||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | |||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h> | |||||
#include <mlir/IR/AsmState.h> | |||||
#include <mlir/InitAllDialects.h> | |||||
#include <mlir/InitAllPasses.h> | |||||
#include <mlir/Pass/Pass.h> | |||||
#include <mlir/Pass/PassManager.h> | |||||
#include <mlir/Support/FileUtilities.h> | |||||
#include <mlir/Support/MlirOptMain.h> | |||||
using namespace llvm; | |||||
using namespace mlir; | |||||
//! TODO: Implement a custom MlirOptMain that supports the following flags. | |||||
static cl::opt<bool> print_mlir{ | |||||
"print-mlir", | |||||
cl::desc("Prints MLIR IR after translation"), | |||||
cl::init(false), | |||||
}; | |||||
static cl::list<std::string> input_values{ | |||||
"input-value", | |||||
cl::desc("Input shapes and optional values"), | |||||
cl::ZeroOrMore, | |||||
}; | |||||
static cl::opt<std::string> input_values_file{ | |||||
"input-value-file", | |||||
cl::desc("Provides a file for input shapes and optional values (see " | |||||
"ParseToVariantListFromFile in vm_util.h for details)"), | |||||
cl::init(""), | |||||
}; | |||||
static cl::opt<bool> run{
"run",
cl::desc("Runs the module (vs. just compiling and verifying)"),
cl::init(true),
};
static cl::list<std::string> run_args{
"run-arg",
cl::desc("Argument passed to the execution flag parser"),
cl::ZeroOrMore,
};

namespace mgb {
namespace jit {
void register_test_mgb_to_affine_lowering_pass();
void register_test_affine_to_llvm_lowering_pass();
} // namespace jit
} // namespace mgb

int main(int argc, char** argv) {
mlir::registerAllPasses();
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
registry.insert<mgb::jit::MgbDialect>();
mgb::jit::register_test_mgb_to_affine_lowering_pass();
mgb::jit::register_test_affine_to_llvm_lowering_pass();
return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver", registry));
}
@@ -41,8 +41,14 @@ pdef('PersistentOutputStorage').add_fields(
Doc('REPRODUCIBLE',
'when profile or heuristic algo selection it require the algos'
'must be reproducible'),
Doc('OPTMIZED',
'profile require algos are optmized to achieve fast-profile')).
Doc('OPTIMIZED',
'profile require algos are optimized to achieve fast-profile'),
default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
(('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'),
(('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'),
(('OPTIMIZED',), 'OPTMIZED'),
]).
add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull'))
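The new `default` and `member_alias` entries rely on the bit-combined (bit_flags) convention used throughout this patch: each member gets its own bit, and an alias names the OR of its members. A simplified sketch of the C++ enum this is expected to produce (an unscoped enum for brevity; the real generated type is a scoped enum with overloaded bitwise operators):

#include <cstdint>

enum Strategy : uint32_t {
    HEURISTIC = 1u << 0,
    PROFILE = 1u << 1,
    REPRODUCIBLE = 1u << 2,
    OPTIMIZED = 1u << 3,
    // aliases from member_alias above
    HEURISTIC_REPRODUCIBLE = HEURISTIC | REPRODUCIBLE,
    PROFILE_REPRODUCIBLE = PROFILE | REPRODUCIBLE,
    PROFILE_HEURISTIC = PROFILE | HEURISTIC,
    OPTMIZED = OPTIMIZED,  // kept for backward compatibility with the old misspelling
};

static_assert(PROFILE_HEURISTIC == (PROFILE | HEURISTIC),
              "an alias is exactly the OR of its members");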