GitOrigin-RevId: d2f22ad5fe
release-1.2
@@ -230,6 +230,10 @@ endif() | |||
# FIXME At present, there are some conflicts between the LLVM that halide | |||
# depends on and the LLVM that MLIR depends on. Should be fixed in subsequent | |||
# versions. | |||
if(MGE_BUILD_IMPERATIVE_RT) | |||
set(MGE_WITH_HALIDE OFF) | |||
message(WARNING "cannot use HALIDE when building IMPERATIVE_RT") | |||
endif() | |||
if(MGE_WITH_JIT_MLIR) | |||
if(MGE_WITH_HALIDE) | |||
message(FATAL_ERROR "please set MGE_WITH_HALIDE to OFF with MGE_WITH_JIT_MLIR enabled") | |||
@@ -310,7 +314,7 @@ if(MGE_INFERENCE_ONLY) | |||
set(MGE_BUILD_IMPERATIVE_RT OFF) | |||
endif() | |||
if(MGE_WITH_JIT_MLIR) | |||
if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) | |||
include(cmake/llvm-project.cmake) | |||
endif() | |||
@@ -750,7 +754,7 @@ target_include_directories(mgb_opr_param_defs | |||
add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs) | |||
install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) | |||
if(MGE_WITH_JIT_MLIR) | |||
if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) | |||
# generate param_defs.td | |||
set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles) | |||
set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir) | |||
@@ -800,12 +804,6 @@ if(TARGET _imperative_rt) | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}> | |||
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}> | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py | |||
DEPENDS _imperative_rt | |||
VERBATIM | |||
) | |||
@@ -0,0 +1,150 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
import io | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
class ConverterWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
_last_param = None | |||
_current_tparams = None | |||
_packed = None | |||
_const = None | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._write("// %s", self._get_header()) | |||
self._write("#ifndef MGB_PARAM") | |||
self._write("#define MGB_PARAM") | |||
self._process(defs) | |||
self._write("#endif // MGB_PARAM") | |||
def _ctype2attr(self, ctype, value): | |||
if ctype == 'uint32_t': | |||
return 'MgbUI32Attr', value | |||
if ctype == 'uint64_t': | |||
return 'MgbUI64Attr', value | |||
if ctype == 'int32_t': | |||
return 'MgbI32Attr', value | |||
if ctype == 'float': | |||
return 'MgbF32Attr', value | |||
if ctype == 'double': | |||
return 'MgbF64Attr', value | |||
if ctype == 'bool': | |||
return 'MgbBoolAttr', value | |||
if ctype == 'DTypeEnum': | |||
self._packed = False | |||
return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) | |||
raise RuntimeError("unknown ctype") | |||
def _on_param_begin(self, p): | |||
self._last_param = p | |||
if p.is_legacy: | |||
self._skip_current_param = True | |||
return | |||
self._packed = True | |||
self._current_tparams = [] | |||
self._const = set() | |||
def _on_param_end(self, p): | |||
if self._skip_current_param: | |||
self._skip_current_param = False | |||
return | |||
if self._packed: | |||
self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) | |||
else: | |||
self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) | |||
self._write("let fields = (ins", indent=1) | |||
self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | |||
self._write(");", indent=-1) | |||
self._write("}\n", indent=-1) | |||
if self._packed: | |||
self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) | |||
self._current_tparams = None | |||
self._packed = None | |||
self._const = None | |||
def _wrapped_with_default_value(self, attr, default): | |||
return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) | |||
def _on_member_enum(self, e): | |||
p = self._last_param | |||
# Note: always generate llvm Record def for enum attribute even it was not | |||
# directly used by any operator, or other enum couldn't alias to this enum | |||
td_class = "{}{}".format(p.name, e.name) | |||
fullname = "::megdnn::param::{}".format(p.name) | |||
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) | |||
def format(v): | |||
return '\"{}\"'.format(str(v)) | |||
enum_def += ','.join(format(i) for i in e.members) | |||
enum_def += "]>" | |||
self._write("def {} : {};".format(td_class, enum_def)) | |||
if self._skip_current_param: | |||
return | |||
# wrapped with default value | |||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) | |||
wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
def _on_member_enum_alias(self, e): | |||
p = self._last_param | |||
if self._skip_current_param: | |||
return | |||
# write enum attr def | |||
td_class = "{}{}".format(p.name, e.name) | |||
fullname = "::megdnn::param::{}".format(p.name) | |||
base_td_class = "{}{}".format(e.src_class, e.src_name) | |||
enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) | |||
self._write("def {} : {};".format(td_class, enum_def)) | |||
# wrapped with default value | |||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) | |||
wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
def _on_member_field(self, f): | |||
if self._skip_current_param: | |||
return | |||
attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | |||
if str(value) in self._const: | |||
value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) | |||
wrapped = self._wrapped_with_default_value(attr, value) | |||
self._current_tparams.append("{}:${}".format(wrapped, f.name)) | |||
def _on_const_field(self, f): | |||
self._const.add(str(f.name)) | |||
def main(): | |||
parser = argparse.ArgumentParser('generate op param tablegen file') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash = input_hash.hexdigest() | |||
writer = ConverterWriter() | |||
with open(args.output, 'w') as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == "__main__": | |||
main() |
@@ -8,9 +8,7 @@ file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/sr | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | |||
file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py) | |||
list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py) | |||
file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||
${PROJECT_SOURCE_DIR}/src/core/include/* | |||
${PROJECT_SOURCE_DIR}/src/opr/include/* | |||
@@ -19,33 +17,8 @@ file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||
${PROJECT_SOURCE_DIR}/dnn/include/*) | |||
set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/) | |||
set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal) | |||
file(MAKE_DIRECTORY ${GEN_OPS_DIR}) | |||
set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py) | |||
set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py) | |||
set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py) | |||
##################### generate python opr_param_defs.py ############## | |||
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) | |||
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) | |||
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS}) | |||
add_custom_command( | |||
OUTPUT ${GEN_OPS_FILE} | |||
COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME} | |||
COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE} | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test | |||
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE} | |||
DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE} | |||
VERBATIM | |||
) | |||
add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE}) | |||
##################### end of opdef generation ######################### | |||
add_subdirectory(tablegen) | |||
add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT}) | |||
@@ -73,7 +46,7 @@ else() | |||
endif() | |||
endif() | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
@@ -87,7 +60,7 @@ if (APPLE OR MSVC OR WIN32) | |||
message(VERBOSE "overwriting SUFFIX at macos and windows before config by set_target_properties") | |||
pybind11_extension(${MODULE_NAME}) | |||
endif() | |||
add_dependencies(${MODULE_NAME} gen_opr_py _version_ld) | |||
add_dependencies(${MODULE_NAME} mgb_opdef _version_ld) | |||
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
add_subdirectory(test) | |||
@@ -19,7 +19,6 @@ from ..ops.builtin import ( | |||
IndexingMultiAxisVec, | |||
IndexingSetMultiAxisVec, | |||
OpDef, | |||
OprAttr, | |||
Reduce, | |||
Reshape, | |||
SetSubtensor, | |||
@@ -31,8 +30,6 @@ from ..tensor.function import Function | |||
from ..tensor.tensor import Tensor | |||
from ..tensor.tensor_wrapper import TensorWrapper | |||
_reduce_sum_param = Reduce(mode="SUM").to_c().param[0] | |||
@functools.singledispatch | |||
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||
@@ -41,17 +38,18 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||
@builtin_op_get_backward_fn.register(OpDef) | |||
def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
if isinstance(op, OprAttr): | |||
grad_fn = _oprAttr_grad_fn.get(op.type, None) | |||
if grad_fn is None: | |||
if op.type == Reduce.name and op.param[0] == _reduce_sum_param: | |||
grad_fn = reduce_sum_grad_fn | |||
else: | |||
grad_fn = default_grad_fn | |||
if isinstance(op, Reshape): | |||
grad_fn = reshape_grad_fn | |||
elif isinstance(op, Subtensor): | |||
grad_fn = subtensor_grad_fn | |||
elif isinstance(op, IndexingMultiAxisVec): | |||
grad_fn = indexingMultiAxisVec_grad_fn | |||
elif isinstance(op, Broadcast) or ( | |||
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
): | |||
grad_fn = elemwise_add_grad_fn | |||
elif isinstance(op, Reduce) and op.mode.name == "SUM": | |||
grad_fn = reduce_sum_grad_fn | |||
else: | |||
grad_fn = default_grad_fn | |||
return grad_fn(op, inputs, outputs, input_requires_grad) | |||
@@ -152,9 +150,7 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad): | |||
# override for Subtensor | |||
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): | |||
grad_op = OprAttr() | |||
grad_op.type = SetSubtensor.name | |||
grad_op.param = op.param | |||
grad_op = SetSubtensor(op.items) | |||
input_shape = get_shape(inputs[0]) | |||
params = inputs[1:] | |||
@@ -175,9 +171,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): | |||
# override for IndexingMultiAxisVec | |||
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad): | |||
grad_op = OprAttr() | |||
grad_op.type = IndexingSetMultiAxisVec.name | |||
grad_op.param = op.param | |||
grad_op = IndexingSetMultiAxisVec(op.items) | |||
input_shape = get_shape(inputs[0]) | |||
params = inputs[1:] | |||
@@ -209,10 +203,3 @@ def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad): | |||
return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,) | |||
return backward, [True] | |||
_oprAttr_grad_fn = { | |||
Reshape.name: reshape_grad_fn, | |||
Subtensor.name: subtensor_grad_fn, | |||
IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn, | |||
} |
@@ -1,8 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
@@ -1,10 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .generated_ops import * | |||
from .misc_ops import * |
@@ -1,939 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import sys | |||
from functools import reduce | |||
from operator import or_ as _or_ | |||
from types import DynamicClassAttribute, MappingProxyType | |||
# try _collections first to reduce startup cost | |||
try: | |||
from _collections import OrderedDict | |||
except ImportError: | |||
from collections import OrderedDict | |||
__all__ = [ | |||
"EnumMeta", | |||
"Enum", | |||
"IntEnum", | |||
"Flag", | |||
"IntFlag", | |||
"auto", | |||
"unique", | |||
] | |||
def _is_descriptor(obj): | |||
"""Returns True if obj is a descriptor, False otherwise.""" | |||
return ( | |||
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||
) | |||
def _is_dunder(name): | |||
"""Returns True if a __dunder__ name, False otherwise.""" | |||
return ( | |||
name[:2] == name[-2:] == "__" | |||
and name[2:3] != "_" | |||
and name[-3:-2] != "_" | |||
and len(name) > 4 | |||
) | |||
def _is_sunder(name): | |||
"""Returns True if a _sunder_ name, False otherwise.""" | |||
return ( | |||
name[0] == name[-1] == "_" | |||
and name[1:2] != "_" | |||
and name[-2:-1] != "_" | |||
and len(name) > 2 | |||
) | |||
def _make_class_unpicklable(cls): | |||
"""Make the given class un-picklable.""" | |||
def _break_on_call_reduce(self, proto): | |||
raise TypeError("%r cannot be pickled" % self) | |||
cls.__reduce_ex__ = _break_on_call_reduce | |||
cls.__module__ = "<unknown>" | |||
_auto_null = object() | |||
class auto: | |||
""" | |||
Instances are replaced with an appropriate value in Enum class suites. | |||
""" | |||
value = _auto_null | |||
class _EnumDict(dict): | |||
""" | |||
Track enum member order and ensure member names are not reused. | |||
EnumMeta will use the names found in self._member_names as the | |||
enumeration member names. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self._member_names = [] | |||
self._last_values = [] | |||
def __setitem__(self, key, value): | |||
""" | |||
Changes anything not dundered or not a descriptor. | |||
If an enum member name is used twice, an error is raised; duplicate | |||
values are not checked for. | |||
Single underscore (sunder) names are reserved. | |||
""" | |||
if _is_sunder(key): | |||
if key not in ( | |||
"_order_", | |||
"_create_pseudo_member_", | |||
"_generate_next_value_", | |||
"_missing_", | |||
): | |||
raise ValueError("_names_ are reserved for future Enum use") | |||
if key == "_generate_next_value_": | |||
setattr(self, "_generate_next_value", value) | |||
elif _is_dunder(key): | |||
if key == "__order__": | |||
key = "_order_" | |||
elif key in self._member_names: | |||
# descriptor overwriting an enum? | |||
raise TypeError("Attempted to reuse key: %r" % key) | |||
elif not _is_descriptor(value): | |||
if key in self: | |||
# enum overwriting a descriptor? | |||
raise TypeError("%r already defined as: %r" % (key, self[key])) | |||
if isinstance(value, auto): | |||
if value.value == _auto_null: | |||
value.value = self._generate_next_value( | |||
key, 1, len(self._member_names), self._last_values[:] | |||
) | |||
value = value.value | |||
self._member_names.append(key) | |||
self._last_values.append(value) | |||
super().__setitem__(key, value) | |||
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course | |||
# until EnumMeta finishes running the first time the Enum class doesn't exist. | |||
# This is also why there are checks in EnumMeta like `if Enum is not None` | |||
Enum = None | |||
class EnumMeta(type): | |||
"""Metaclass for Enum""" | |||
@classmethod | |||
def __prepare__(metacls, cls, bases): | |||
# create the namespace dict | |||
enum_dict = _EnumDict() | |||
# inherit previous flags and _generate_next_value_ function | |||
member_type, first_enum = metacls._get_mixins_(bases) | |||
if first_enum is not None: | |||
enum_dict["_generate_next_value_"] = getattr( | |||
first_enum, "_generate_next_value_", None | |||
) | |||
return enum_dict | |||
def __new__(metacls, cls, bases, classdict): | |||
# an Enum class is final once enumeration items have been defined; it | |||
# cannot be mixed with other types (int, float, etc.) if it has an | |||
# inherited __new__ unless a new __new__ is defined (or the resulting | |||
# class will fail). | |||
member_type, first_enum = metacls._get_mixins_(bases) | |||
__new__, save_new, use_args = metacls._find_new_( | |||
classdict, member_type, first_enum | |||
) | |||
# save enum items into separate mapping so they don't get baked into | |||
# the new class | |||
enum_members = {k: classdict[k] for k in classdict._member_names} | |||
for name in classdict._member_names: | |||
del classdict[name] | |||
# adjust the sunders | |||
_order_ = classdict.pop("_order_", None) | |||
# check for illegal enum names (any others?) | |||
invalid_names = set(enum_members) & { | |||
"mro", | |||
} | |||
if invalid_names: | |||
raise ValueError( | |||
"Invalid enum member name: {0}".format(",".join(invalid_names)) | |||
) | |||
# create a default docstring if one has not been provided | |||
if "__doc__" not in classdict: | |||
classdict["__doc__"] = "An enumeration." | |||
# create our new Enum type | |||
enum_class = super().__new__(metacls, cls, bases, classdict) | |||
enum_class._member_names_ = [] # names in definition order | |||
enum_class._member_map_ = OrderedDict() # name->value map | |||
enum_class._member_type_ = member_type | |||
# save attributes from super classes so we know if we can take | |||
# the shortcut of storing members in the class dict | |||
base_attributes = {a for b in enum_class.mro() for a in b.__dict__} | |||
# Reverse value->name map for hashable values. | |||
enum_class._value2member_map_ = {} | |||
# If a custom type is mixed into the Enum, and it does not know how | |||
# to pickle itself, pickle.dumps will succeed but pickle.loads will | |||
# fail. Rather than have the error show up later and possibly far | |||
# from the source, sabotage the pickle protocol for this class so | |||
# that pickle.dumps also fails. | |||
# | |||
# However, if the new class implements its own __reduce_ex__, do not | |||
# sabotage -- it's on them to make sure it works correctly. We use | |||
# __reduce_ex__ instead of any of the others as it is preferred by | |||
# pickle over __reduce__, and it handles all pickle protocols. | |||
if "__reduce_ex__" not in classdict: | |||
if member_type is not object: | |||
methods = ( | |||
"__getnewargs_ex__", | |||
"__getnewargs__", | |||
"__reduce_ex__", | |||
"__reduce__", | |||
) | |||
if not any(m in member_type.__dict__ for m in methods): | |||
_make_class_unpicklable(enum_class) | |||
# instantiate them, checking for duplicates as we go | |||
# we instantiate first instead of checking for duplicates first in case | |||
# a custom __new__ is doing something funky with the values -- such as | |||
# auto-numbering ;) | |||
for member_name in classdict._member_names: | |||
value = enum_members[member_name] | |||
if not isinstance(value, tuple): | |||
args = (value,) | |||
else: | |||
args = value | |||
if member_type is tuple: # special case for tuple enums | |||
args = (args,) # wrap it one more time | |||
if not use_args: | |||
enum_member = __new__(enum_class) | |||
if not hasattr(enum_member, "_value_"): | |||
enum_member._value_ = value | |||
else: | |||
enum_member = __new__(enum_class, *args) | |||
if not hasattr(enum_member, "_value_"): | |||
if member_type is object: | |||
enum_member._value_ = value | |||
else: | |||
enum_member._value_ = member_type(*args) | |||
value = enum_member._value_ | |||
enum_member._name_ = member_name | |||
enum_member.__objclass__ = enum_class | |||
enum_member.__init__(*args) | |||
# If another member with the same value was already defined, the | |||
# new member becomes an alias to the existing one. | |||
for name, canonical_member in enum_class._member_map_.items(): | |||
if canonical_member._value_ == enum_member._value_: | |||
enum_member = canonical_member | |||
break | |||
else: | |||
# Aliases don't appear in member names (only in __members__). | |||
enum_class._member_names_.append(member_name) | |||
# performance boost for any member that would not shadow | |||
# a DynamicClassAttribute | |||
if member_name not in base_attributes: | |||
setattr(enum_class, member_name, enum_member) | |||
# now add to _member_map_ | |||
enum_class._member_map_[member_name] = enum_member | |||
try: | |||
# This may fail if value is not hashable. We can't add the value | |||
# to the map, and by-value lookups for this value will be | |||
# linear. | |||
enum_class._value2member_map_[value] = enum_member | |||
except TypeError: | |||
pass | |||
# double check that repr and friends are not the mixin's or various | |||
# things break (such as pickle) | |||
for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"): | |||
class_method = getattr(enum_class, name) | |||
obj_method = getattr(member_type, name, None) | |||
enum_method = getattr(first_enum, name, None) | |||
if obj_method is not None and obj_method is class_method: | |||
setattr(enum_class, name, enum_method) | |||
# replace any other __new__ with our own (as long as Enum is not None, | |||
# anyway) -- again, this is to support pickle | |||
if Enum is not None: | |||
# if the user defined their own __new__, save it before it gets | |||
# clobbered in case they subclass later | |||
if save_new: | |||
enum_class.__new_member__ = __new__ | |||
enum_class.__new__ = Enum.__new__ | |||
# py3 support for definition order (helps keep py2/py3 code in sync) | |||
if _order_ is not None: | |||
if isinstance(_order_, str): | |||
_order_ = _order_.replace(",", " ").split() | |||
if _order_ != enum_class._member_names_: | |||
raise TypeError("member order does not match _order_") | |||
return enum_class | |||
def __bool__(self): | |||
""" | |||
classes/types should always be True. | |||
""" | |||
return True | |||
def __call__( | |||
cls, value, names=None, *, module=None, qualname=None, type=None, start=1 | |||
): | |||
""" | |||
Either returns an existing member, or creates a new enum class. | |||
This method is used both when an enum class is given a value to match | |||
to an enumeration member (i.e. Color(3)) and for the functional API | |||
(i.e. Color = Enum('Color', names='RED GREEN BLUE')). | |||
When used for the functional API: | |||
`value` will be the name of the new class. | |||
`names` should be either a string of white-space/comma delimited names | |||
(values will start at `start`), or an iterator/mapping of name, value pairs. | |||
`module` should be set to the module this class is being created in; | |||
if it is not set, an attempt to find that module will be made, but if | |||
it fails the class will not be picklable. | |||
`qualname` should be set to the actual location this class can be found | |||
at in its module; by default it is set to the global scope. If this is | |||
not correct, unpickling will fail in some circumstances. | |||
`type`, if set, will be mixed in as the first base class. | |||
""" | |||
if names is None: # simple value lookup | |||
return cls.__new__(cls, value) | |||
# otherwise, functional API: we're creating a new Enum type | |||
return cls._create_( | |||
value, names, module=module, qualname=qualname, type=type, start=start | |||
) | |||
def __contains__(cls, member): | |||
return isinstance(member, cls) and member._name_ in cls._member_map_ | |||
def __delattr__(cls, attr): | |||
# nicer error message when someone tries to delete an attribute | |||
# (see issue19025). | |||
if attr in cls._member_map_: | |||
raise AttributeError("%s: cannot delete Enum member." % cls.__name__) | |||
super().__delattr__(attr) | |||
def __dir__(self): | |||
return [ | |||
"__class__", | |||
"__doc__", | |||
"__members__", | |||
"__module__", | |||
] + self._member_names_ | |||
def __getattr__(cls, name): | |||
""" | |||
Return the enum member matching `name` | |||
We use __getattr__ instead of descriptors or inserting into the enum | |||
class' __dict__ in order to support `name` and `value` being both | |||
properties for enum members (which live in the class' __dict__) and | |||
enum members themselves. | |||
""" | |||
if _is_dunder(name): | |||
raise AttributeError(name) | |||
try: | |||
return cls._member_map_[name] | |||
except KeyError: | |||
raise AttributeError(name) from None | |||
def __getitem__(cls, name): | |||
return cls._member_map_[name] | |||
def __iter__(cls): | |||
return (cls._member_map_[name] for name in cls._member_names_) | |||
def __len__(cls): | |||
return len(cls._member_names_) | |||
@property | |||
def __members__(cls): | |||
""" | |||
Returns a mapping of member name->value. | |||
This mapping lists all enum members, including aliases. Note that this | |||
is a read-only view of the internal mapping. | |||
""" | |||
return MappingProxyType(cls._member_map_) | |||
def __repr__(cls): | |||
return "<enum %r>" % cls.__name__ | |||
def __reversed__(cls): | |||
return (cls._member_map_[name] for name in reversed(cls._member_names_)) | |||
def __setattr__(cls, name, value): | |||
""" | |||
Block attempts to reassign Enum members. | |||
A simple assignment to the class namespace only changes one of the | |||
several possible ways to get an Enum member from the Enum class, | |||
resulting in an inconsistent Enumeration. | |||
""" | |||
member_map = cls.__dict__.get("_member_map_", {}) | |||
if name in member_map: | |||
raise AttributeError("Cannot reassign members.") | |||
super().__setattr__(name, value) | |||
def _create_( | |||
cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1 | |||
): | |||
""" | |||
Convenience method to create a new Enum class. | |||
`names` can be: | |||
* A string containing member names, separated either with spaces or | |||
commas. Values are incremented by 1 from `start`. | |||
* An iterable of member names. Values are incremented by 1 from `start`. | |||
* An iterable of (member name, value) pairs. | |||
* A mapping of member name -> value pairs. | |||
""" | |||
metacls = cls.__class__ | |||
bases = (cls,) if type is None else (type, cls) | |||
_, first_enum = cls._get_mixins_(bases) | |||
classdict = metacls.__prepare__(class_name, bases) | |||
# special processing needed for names? | |||
if isinstance(names, str): | |||
names = names.replace(",", " ").split() | |||
if isinstance(names, (tuple, list)) and names and isinstance(names[0], str): | |||
original_names, names = names, [] | |||
last_values = [] | |||
for count, name in enumerate(original_names): | |||
value = first_enum._generate_next_value_( | |||
name, start, count, last_values[:] | |||
) | |||
last_values.append(value) | |||
names.append((name, value)) | |||
# Here, names is either an iterable of (name, value) or a mapping. | |||
for item in names: | |||
if isinstance(item, str): | |||
member_name, member_value = item, names[item] | |||
else: | |||
member_name, member_value = item | |||
classdict[member_name] = member_value | |||
enum_class = metacls.__new__(metacls, class_name, bases, classdict) | |||
# TODO: replace the frame hack if a blessed way to know the calling | |||
# module is ever developed | |||
if module is None: | |||
try: | |||
module = sys._getframe(2).f_globals["__name__"] | |||
except (AttributeError, ValueError) as exc: | |||
pass | |||
if module is None: | |||
_make_class_unpicklable(enum_class) | |||
else: | |||
enum_class.__module__ = module | |||
if qualname is not None: | |||
enum_class.__qualname__ = qualname | |||
return enum_class | |||
@staticmethod | |||
def _get_mixins_(bases): | |||
""" | |||
Returns the type for creating enum members, and the first inherited | |||
enum class. | |||
bases: the tuple of bases that was given to __new__ | |||
""" | |||
if not bases: | |||
return object, Enum | |||
# double check that we are not subclassing a class with existing | |||
# enumeration members; while we're at it, see if any other data | |||
# type has been mixed in so we can use the correct __new__ | |||
member_type = first_enum = None | |||
for base in bases: | |||
if base is not Enum and issubclass(base, Enum) and base._member_names_: | |||
raise TypeError("Cannot extend enumerations") | |||
# base is now the last base in bases | |||
if not issubclass(base, Enum): | |||
raise TypeError( | |||
"new enumerations must be created as " | |||
"`ClassName([mixin_type,] enum_type)`" | |||
) | |||
# get correct mix-in type (either mix-in type of Enum subclass, or | |||
# first base if last base is Enum) | |||
if not issubclass(bases[0], Enum): | |||
member_type = bases[0] # first data type | |||
first_enum = bases[-1] # enum type | |||
else: | |||
for base in bases[0].__mro__: | |||
# most common: (IntEnum, int, Enum, object) | |||
# possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>, | |||
# <class 'int'>, <Enum 'Enum'>, | |||
# <class 'object'>) | |||
if issubclass(base, Enum): | |||
if first_enum is None: | |||
first_enum = base | |||
else: | |||
if member_type is None: | |||
member_type = base | |||
return member_type, first_enum | |||
@staticmethod | |||
def _find_new_(classdict, member_type, first_enum): | |||
""" | |||
Returns the __new__ to be used for creating the enum members. | |||
classdict: the class dictionary given to __new__ | |||
member_type: the data type whose __new__ will be used by default | |||
first_enum: enumeration to check for an overriding __new__ | |||
""" | |||
# now find the correct __new__, checking to see of one was defined | |||
# by the user; also check earlier enum classes in case a __new__ was | |||
# saved as __new_member__ | |||
__new__ = classdict.get("__new__", None) | |||
# should __new__ be saved as __new_member__ later? | |||
save_new = __new__ is not None | |||
if __new__ is None: | |||
# check all possibles for __new_member__ before falling back to | |||
# __new__ | |||
for method in ("__new_member__", "__new__"): | |||
for possible in (member_type, first_enum): | |||
target = getattr(possible, method, None) | |||
if target not in { | |||
None, | |||
None.__new__, | |||
object.__new__, | |||
Enum.__new__, | |||
}: | |||
__new__ = target | |||
break | |||
if __new__ is not None: | |||
break | |||
else: | |||
__new__ = object.__new__ | |||
# if a non-object.__new__ is used then whatever value/tuple was | |||
# assigned to the enum member name will be passed to __new__ and to the | |||
# new enum member's __init__ | |||
if __new__ is object.__new__: | |||
use_args = False | |||
else: | |||
use_args = True | |||
return __new__, save_new, use_args | |||
class Enum(metaclass=EnumMeta): | |||
""" | |||
Generic enumeration. | |||
Derive from this class to define new enumerations. | |||
""" | |||
def __new__(cls, value): | |||
# all enum instances are actually created during class construction | |||
# without calling this method; this method is called by the metaclass' | |||
# __call__ (i.e. Color(3) ), and by pickle | |||
if type(value) is cls: | |||
# For lookups like Color(Color.RED) | |||
return value | |||
# by-value search for a matching enum member | |||
# see if it's in the reverse mapping (for hashable values) | |||
try: | |||
if value in cls._value2member_map_: | |||
return cls._value2member_map_[value] | |||
except TypeError: | |||
# not there, now do long search -- O(n) behavior | |||
for member in cls._member_map_.values(): | |||
if member._value_ == value: | |||
return member | |||
# still not found -- try _missing_ hook | |||
return cls._missing_(value) | |||
def _generate_next_value_(name, start, count, last_values): | |||
for last_value in reversed(last_values): | |||
try: | |||
return last_value + 1 | |||
except TypeError: | |||
pass | |||
else: | |||
return start | |||
@classmethod | |||
def _missing_(cls, value): | |||
raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
def __repr__(self): | |||
return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_) | |||
def __str__(self): | |||
return "%s.%s" % (self.__class__.__name__, self._name_) | |||
def __dir__(self): | |||
added_behavior = [ | |||
m | |||
for cls in self.__class__.mro() | |||
for m in cls.__dict__ | |||
if m[0] != "_" and m not in self._member_map_ | |||
] | |||
return ["__class__", "__doc__", "__module__"] + added_behavior | |||
def __format__(self, format_spec): | |||
# mixed-in Enums should use the mixed-in type's __format__, otherwise | |||
# we can get strange results with the Enum name showing up instead of | |||
# the value | |||
# pure Enum branch | |||
if self._member_type_ is object: | |||
cls = str | |||
val = str(self) | |||
# mix-in branch | |||
else: | |||
cls = self._member_type_ | |||
val = self._value_ | |||
return cls.__format__(val, format_spec) | |||
def __hash__(self): | |||
return hash(self._name_) | |||
def __reduce_ex__(self, proto): | |||
return self.__class__, (self._value_,) | |||
# DynamicClassAttribute is used to provide access to the `name` and | |||
# `value` properties of enum members while keeping some measure of | |||
# protection from modification, while still allowing for an enumeration | |||
# to have members named `name` and `value`. This works because enumeration | |||
# members are not set directly on the enum class -- __getattr__ is | |||
# used to look them up. | |||
@DynamicClassAttribute | |||
def name(self): | |||
"""The name of the Enum member.""" | |||
return self._name_ | |||
@DynamicClassAttribute | |||
def value(self): | |||
"""The value of the Enum member.""" | |||
return self._value_ | |||
@classmethod | |||
def _convert(cls, name, module, filter, source=None): | |||
""" | |||
Create a new Enum subclass that replaces a collection of global constants | |||
""" | |||
# convert all constants from source (or module) that pass filter() to | |||
# a new Enum called name, and export the enum and its members back to | |||
# module; | |||
# also, replace the __reduce_ex__ method so unpickling works in | |||
# previous Python versions | |||
module_globals = vars(sys.modules[module]) | |||
if source: | |||
source = vars(source) | |||
else: | |||
source = module_globals | |||
# We use an OrderedDict of sorted source keys so that the | |||
# _value2member_map is populated in the same order every time | |||
# for a consistent reverse mapping of number to name when there | |||
# are multiple names for the same number rather than varying | |||
# between runs due to hash randomization of the module dictionary. | |||
members = [(name, source[name]) for name in source.keys() if filter(name)] | |||
try: | |||
# sort by value | |||
members.sort(key=lambda t: (t[1], t[0])) | |||
except TypeError: | |||
# unless some values aren't comparable, in which case sort by name | |||
members.sort(key=lambda t: t[0]) | |||
cls = cls(name, members, module=module) | |||
cls.__reduce_ex__ = _reduce_ex_by_name | |||
module_globals.update(cls.__members__) | |||
module_globals[name] = cls | |||
return cls | |||
class IntEnum(int, Enum): | |||
"""Enum where members are also (and must be) ints""" | |||
def _reduce_ex_by_name(self, proto): | |||
return self.name | |||
class Flag(Enum): | |||
"""Support for flags""" | |||
def _generate_next_value_(name, start, count, last_values): | |||
""" | |||
Generate the next value when not given. | |||
name: the name of the member | |||
start: the initital start value or None | |||
count: the number of existing members | |||
last_value: the last value assigned or None | |||
""" | |||
if not count: | |||
return start if start is not None else 1 | |||
for last_value in reversed(last_values): | |||
try: | |||
high_bit = _high_bit(last_value) | |||
break | |||
except Exception: | |||
raise TypeError("Invalid Flag value: %r" % last_value) from None | |||
return 2 ** (high_bit + 1) | |||
@classmethod | |||
def _missing_(cls, value): | |||
original_value = value | |||
if value < 0: | |||
value = ~value | |||
possible_member = cls._create_pseudo_member_(value) | |||
if original_value < 0: | |||
possible_member = ~possible_member | |||
return possible_member | |||
@classmethod | |||
def _create_pseudo_member_(cls, value): | |||
""" | |||
Create a composite member iff value contains only members. | |||
""" | |||
pseudo_member = cls._value2member_map_.get(value, None) | |||
if pseudo_member is None: | |||
# verify all bits are accounted for | |||
_, extra_flags = _decompose(cls, value) | |||
if extra_flags: | |||
raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
# construct a singleton enum pseudo-member | |||
pseudo_member = object.__new__(cls) | |||
pseudo_member._name_ = None | |||
pseudo_member._value_ = value | |||
# use setdefault in case another thread already created a composite | |||
# with this value | |||
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||
return pseudo_member | |||
def __contains__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return other._value_ & self._value_ == other._value_ | |||
def __repr__(self): | |||
cls = self.__class__ | |||
if self._name_ is not None: | |||
return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_) | |||
members, uncovered = _decompose(cls, self._value_) | |||
return "<%s.%s: %r>" % ( | |||
cls.__name__, | |||
"|".join([str(m._name_ or m._value_) for m in members]), | |||
self._value_, | |||
) | |||
def __str__(self): | |||
cls = self.__class__ | |||
if self._name_ is not None: | |||
return "%s.%s" % (cls.__name__, self._name_) | |||
members, uncovered = _decompose(cls, self._value_) | |||
if len(members) == 1 and members[0]._name_ is None: | |||
return "%s.%r" % (cls.__name__, members[0]._value_) | |||
else: | |||
return "%s.%s" % ( | |||
cls.__name__, | |||
"|".join([str(m._name_ or m._value_) for m in members]), | |||
) | |||
def __bool__(self): | |||
return bool(self._value_) | |||
def __or__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return self.__class__(self._value_ | other._value_) | |||
def __and__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return self.__class__(self._value_ & other._value_) | |||
def __xor__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return self.__class__(self._value_ ^ other._value_) | |||
def __invert__(self): | |||
members, uncovered = _decompose(self.__class__, self._value_) | |||
inverted_members = [ | |||
m | |||
for m in self.__class__ | |||
if m not in members and not m._value_ & self._value_ | |||
] | |||
inverted = reduce(_or_, inverted_members, self.__class__(0)) | |||
return self.__class__(inverted) | |||
class IntFlag(int, Flag): | |||
"""Support for integer-based Flags""" | |||
@classmethod | |||
def _missing_(cls, value): | |||
if not isinstance(value, int): | |||
raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
new_member = cls._create_pseudo_member_(value) | |||
return new_member | |||
@classmethod | |||
def _create_pseudo_member_(cls, value): | |||
pseudo_member = cls._value2member_map_.get(value, None) | |||
if pseudo_member is None: | |||
need_to_create = [value] | |||
# get unaccounted for bits | |||
_, extra_flags = _decompose(cls, value) | |||
# timer = 10 | |||
while extra_flags: | |||
# timer -= 1 | |||
bit = _high_bit(extra_flags) | |||
flag_value = 2 ** bit | |||
if ( | |||
flag_value not in cls._value2member_map_ | |||
and flag_value not in need_to_create | |||
): | |||
need_to_create.append(flag_value) | |||
if extra_flags == -flag_value: | |||
extra_flags = 0 | |||
else: | |||
extra_flags ^= flag_value | |||
for value in reversed(need_to_create): | |||
# construct singleton pseudo-members | |||
pseudo_member = int.__new__(cls, value) | |||
pseudo_member._name_ = None | |||
pseudo_member._value_ = value | |||
# use setdefault in case another thread already created a composite | |||
# with this value | |||
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||
return pseudo_member | |||
def __or__(self, other): | |||
if not isinstance(other, (self.__class__, int)): | |||
return NotImplemented | |||
result = self.__class__(self._value_ | self.__class__(other)._value_) | |||
return result | |||
def __and__(self, other): | |||
if not isinstance(other, (self.__class__, int)): | |||
return NotImplemented | |||
return self.__class__(self._value_ & self.__class__(other)._value_) | |||
def __xor__(self, other): | |||
if not isinstance(other, (self.__class__, int)): | |||
return NotImplemented | |||
return self.__class__(self._value_ ^ self.__class__(other)._value_) | |||
__ror__ = __or__ | |||
__rand__ = __and__ | |||
__rxor__ = __xor__ | |||
def __invert__(self): | |||
result = self.__class__(~self._value_) | |||
return result | |||
def _high_bit(value): | |||
"""returns index of highest bit, or -1 if value is zero or negative""" | |||
return value.bit_length() - 1 | |||
def unique(enumeration): | |||
"""Class decorator for enumerations ensuring unique member values.""" | |||
duplicates = [] | |||
for name, member in enumeration.__members__.items(): | |||
if name != member.name: | |||
duplicates.append((name, member.name)) | |||
if duplicates: | |||
alias_details = ", ".join( | |||
["%s -> %s" % (alias, name) for (alias, name) in duplicates] | |||
) | |||
raise ValueError( | |||
"duplicate values found in %r: %s" % (enumeration, alias_details) | |||
) | |||
return enumeration | |||
def _decompose(flag, value): | |||
"""Extract all members from the value.""" | |||
# _decompose is only called if the value is not named | |||
not_covered = value | |||
negative = value < 0 | |||
# issue29167: wrap accesses to _value2member_map_ in a list to avoid race | |||
# conditions between iterating over it and having more psuedo- | |||
# members added to it | |||
if negative: | |||
# only check for named flags | |||
flags_to_check = [ | |||
(m, v) | |||
for v, m in list(flag._value2member_map_.items()) | |||
if m.name is not None | |||
] | |||
else: | |||
# check for named flags and powers-of-two flags | |||
flags_to_check = [ | |||
(m, v) | |||
for v, m in list(flag._value2member_map_.items()) | |||
if m.name is not None or _power_of_two(v) | |||
] | |||
members = [] | |||
for member, member_value in flags_to_check: | |||
if member_value and member_value & value == member_value: | |||
members.append(member) | |||
not_covered &= ~member_value | |||
if not members and value in flag._value2member_map_: | |||
members.append(flag._value2member_map_[value]) | |||
members.sort(key=lambda m: m._value_, reverse=True) | |||
if len(members) > 1 and members[0].value == value: | |||
# we have the breakdown, don't need the value member itself | |||
members.pop(0) | |||
return members, not_covered | |||
def _power_of_two(value): | |||
if value < 1: | |||
return False | |||
return value == 2 ** _high_bit(value) |
@@ -1,94 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import warnings | |||
from ..._imperative_rt.ops import OprAttr | |||
from . import param_defs | |||
def make_param(param, ptype, kwargs): | |||
if param is not None: | |||
if isinstance(param, ptype): | |||
return param | |||
param = [param] | |||
assert len(param) == len( | |||
ptype.__slots__ | |||
), "{} needs {} params, but {} are provided".format( | |||
ptype, len(ptype.__slots__), len(param) | |||
) | |||
return ptype(*param) | |||
ckw = {} | |||
for i in ptype.__slots__: | |||
val = kwargs.pop(i, ckw) | |||
if val is not ckw: | |||
ckw[i] = val | |||
return ptype(**ckw) | |||
class PodOpVisitor: | |||
__name2subclass = {} | |||
__c = None | |||
name = None | |||
param_names = [] | |||
config = None | |||
def __init__(self, config, **params): | |||
self.config = config | |||
assert set(params) == set(self.param_names) | |||
self.__dict__.update(params) | |||
def __init_subclass__(cls, **kwargs): | |||
super().__init_subclass__(**kwargs) # python 3.5 does not have this | |||
name = cls.name | |||
if name in cls.__name2subclass: | |||
if not issubclass(cls, cls.__name2subclass[name]): | |||
warnings.warn("Multiple subclasses for bultin op: %s" % name) | |||
cls.__name2subclass[name] = cls | |||
def to_c(self): | |||
if self.__c: | |||
return self.__c | |||
op = OprAttr() | |||
op.type = self.name | |||
if self.config is not None: | |||
op.config = self.config | |||
# first 4 bytes is TAG, has to remove them currently | |||
op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names) | |||
self.__c = op | |||
return op | |||
def __eq__(self, rhs): | |||
return self.to_c() == rhs.to_c() | |||
def __repr__(self): | |||
name = self.__class__.__name__ | |||
if self.__c: | |||
return "{}(<binary data>)".format(name) | |||
kwargs = {} | |||
for i in self.param_names: | |||
p = self.__dict__[i] | |||
if isinstance(p, param_defs._ParamDefBase): | |||
for k in p.__slots__: | |||
v = getattr(p, k) | |||
if isinstance(v, param_defs._EnumBase): | |||
v = v.name | |||
kwargs[k] = repr(v) | |||
else: | |||
kwargs[i] = repr(p) | |||
if self.config: | |||
if len(self.config.comp_node_arr) == 1: | |||
kwargs["device"] = "'%s'" % self.config.comp_node | |||
return "{}({})".format( | |||
name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items()) | |||
) |
@@ -1,194 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import ctypes | |||
from ..._imperative_rt import OperatorNodeConfig as Config | |||
from . import param_defs | |||
from .helper import PodOpVisitor, make_param | |||
__all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"] | |||
class TensorShape: | |||
MAX_NDIM = 7 | |||
class ConvolutionBackwardData(PodOpVisitor): | |||
param_names = ( | |||
"param", | |||
"execution_polity", | |||
) | |||
name = "ConvolutionBackwardDataV1" | |||
def __init__( | |||
self, | |||
*, | |||
param=None, | |||
execution_polity=None, | |||
name=None, | |||
comp_node=None, | |||
config=None, | |||
dtype=None, | |||
**kwargs | |||
): | |||
config = config or Config() | |||
if name: | |||
config.name = name | |||
if comp_node: | |||
config.comp_node = comp_node | |||
if dtype: | |||
config.dtype = dtype | |||
self.config = config | |||
self.param = make_param(param, param_defs.Convolution, kwargs) | |||
self.execution_polity = make_param( | |||
execution_polity, param_defs.ExecutionPolicy, kwargs | |||
) | |||
assert not kwargs, "extra kwargs: {}".format(kwargs) | |||
class Dimshuffle(PodOpVisitor): | |||
name = "Dimshuffle" | |||
param_names = ("pattern",) | |||
class Pattern(ctypes.Structure): | |||
Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM | |||
_fields_ = [ | |||
("length", ctypes.c_uint32), | |||
("pattern", Pattern_Array), | |||
("ndim", ctypes.c_uint32), | |||
] | |||
def serialize(self): | |||
return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
def __init__(self, pattern, ndim=0): | |||
assert isinstance(pattern, collections.abc.Iterable) | |||
assert len(pattern) <= TensorShape.MAX_NDIM | |||
pattern_array = Dimshuffle.Pattern.Pattern_Array() | |||
for idx, v in enumerate(pattern): | |||
pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v)) | |||
self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim) | |||
class Reshape(PodOpVisitor): | |||
name = "ReshapeV1" | |||
param_names = ("unspec_axis",) | |||
def __init__(self, unspec_axis=None): | |||
if unspec_axis is None: | |||
self.unspec_axis = param_defs.OptionalAxisV1() | |||
else: | |||
self.unspec_axis = param_defs.OptionalAxisV1(unspec_axis) | |||
class AxisNum(ctypes.Structure): | |||
_fields_ = [ | |||
("m_num", ctypes.c_int), | |||
] | |||
class AxisDesc(ctypes.Structure): | |||
class Method(ctypes.c_int): | |||
ADD_1 = 0 | |||
REMOVE = 1 | |||
_fields_ = [ | |||
("method", Method), | |||
("axis", AxisNum), | |||
] | |||
@classmethod | |||
def make_add(cls, axis): | |||
return cls(cls.Method.ADD_1, AxisNum(axis)) | |||
@classmethod | |||
def make_remove(cls, axis): | |||
return cls(cls.Method.REMOVE, AxisNum(axis)) | |||
class AxisAddRemove(PodOpVisitor): | |||
name = "AxisAddRemove" | |||
param_names = ("param",) | |||
AxisDesc = AxisDesc | |||
class Param(ctypes.Structure): | |||
MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2 | |||
_fields_ = [("nr_desc", ctypes.c_uint32), ("desc", AxisDesc * MAX_DESC_SIZE)] | |||
def __init__(self, *args): | |||
super().__init__() | |||
self.nr_desc = len(args) | |||
for i, a in enumerate(args): | |||
self.desc[i] = a | |||
def serialize(self): | |||
return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
def __init__(self, param): | |||
assert isinstance(param, self.Param) | |||
self.param = param | |||
del AxisDesc | |||
class IndexingOpBase(PodOpVisitor): | |||
param_names = ("index_desc",) | |||
class IndexDescMaskDump(ctypes.Structure): | |||
class Item(ctypes.Structure): | |||
_fields_ = [ | |||
("axis", ctypes.c_int8), | |||
("begin", ctypes.c_bool), | |||
("end", ctypes.c_bool), | |||
("step", ctypes.c_bool), | |||
("idx", ctypes.c_bool), | |||
] | |||
Item_Array = Item * TensorShape.MAX_NDIM | |||
_fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)] | |||
def serialize(self): | |||
return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
def __init__(self, items): | |||
nr_item = len(items) | |||
assert nr_item <= TensorShape.MAX_NDIM | |||
item_array = IndexingOpBase.IndexDescMaskDump.Item_Array() | |||
for idx, item in enumerate(items): | |||
assert isinstance(item, (tuple, list)) and len(item) == 5 | |||
item_array[idx] = IndexingOpBase.IndexDescMaskDump.Item(*item) | |||
self.index_desc = IndexingOpBase.IndexDescMaskDump(nr_item, item_array) | |||
def _gen_indexing_defs(*names): | |||
for name in names: | |||
globals()[name] = type(name, (IndexingOpBase,), dict(name=name)) | |||
__all__.append(name) | |||
_gen_indexing_defs( | |||
"Subtensor", | |||
"SetSubtensor", | |||
"IncrSubtensor", | |||
"IndexingMultiAxisVec", | |||
"IndexingSetMultiAxisVec", | |||
"IndexingIncrMultiAxisVec", | |||
"MeshIndexing", | |||
"IncrMeshIndexing", | |||
"SetMeshIndexing", | |||
"BatchedMeshIndexing", | |||
"BatchedIncrMeshIndexing", | |||
"BatchedSetMeshIndexing", | |||
) |
@@ -11,25 +11,12 @@ from typing import Union | |||
from ..._imperative_rt import OpDef, ops | |||
from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
from .._internal import all_ops | |||
from .._internal.helper import PodOpVisitor | |||
# register OpDef as a "virtual subclass" of OpBase, so any of registered | |||
# apply(OpBase, ...) rules could work well on OpDef | |||
OpBase.register(OpDef) | |||
# forward to apply(OpDef, ...) | |||
@apply.register() | |||
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | |||
return apply(op.to_c(), *args) | |||
__all__ = ["OpDef", "PodOpVisitor"] | |||
for k, v in all_ops.__dict__.items(): | |||
if isinstance(v, type) and issubclass(v, PodOpVisitor): | |||
globals()[k] = v | |||
__all__.append(k) | |||
__all__ = ["OpDef"] | |||
for k, v in ops.__dict__.items(): | |||
if isinstance(v, type) and issubclass(v, OpDef): | |||
@@ -90,7 +90,7 @@ def _reshape(x, shape): | |||
if unspec_axis is None: | |||
op = builtin.Reshape() | |||
else: | |||
op = builtin.Reshape(unspec_axis=unspec_axis) | |||
op = builtin.Reshape(axis=unspec_axis) | |||
(x,) = apply(op, x, shape) | |||
return x | |||
@@ -144,8 +144,6 @@ def _logical_binary_elwise(mode, rev=False): | |||
def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
Param = builtin.AxisAddRemove.Param | |||
def get_axes(): | |||
if axis is None: | |||
return [i for i, s in enumerate(inp.shape) if s == 1] | |||
@@ -159,8 +157,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||
axis = [a - i for i, a in enumerate(axis)] | |||
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) | |||
op = builtin.AxisAddRemove(param=param) | |||
op = builtin.RemoveAxis(axis=axis) | |||
(result,) = apply(op, inp) | |||
if len(axis) == inp.ndim: | |||
setscalar(result) | |||
@@ -134,7 +134,7 @@ def astype(x, dtype): | |||
dtype = np.dtype(dtype) | |||
if not is_equal(x.dtype, dtype): | |||
isscalar = x.__wrapped__._data._isscalar | |||
(x,) = apply(builtin.TypeCvt(param=dtype), x) | |||
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) | |||
x.__wrapped__._data._isscalar = isscalar | |||
return x | |||
@@ -8,7 +8,6 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional, Tuple | |||
from ..core._imperative_rt.ops import CollectiveCommMode | |||
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
from ..core.autodiff.grad import ( | |||
Tracer, | |||
@@ -110,17 +109,20 @@ def collective_comm(inp, mode, group, device): | |||
assert isinstance(group, Group) | |||
if group is None: | |||
return inp | |||
op = CollectiveComm() | |||
op.key = group.key | |||
op.nr_devices = group.size | |||
op.rank = group.rank | |||
op.is_root = op.rank == 0 | |||
op.local_grad = False | |||
op.addr, op.port = get_mm_server_addr() | |||
op.mode = mode | |||
op.dtype = inp.dtype | |||
op.backend = get_backend() | |||
op.comp_node = device | |||
addr, port = get_mm_server_addr() | |||
op = CollectiveComm( | |||
key=group.key, | |||
nr_devices=group.size, | |||
rank=group.rank, | |||
is_root=(group.rank == 0), | |||
local_grad=False, | |||
addr=addr, | |||
port=port, | |||
mode=mode, | |||
dtype=inp.dtype, | |||
backend=get_backend(), | |||
comp_node=device, | |||
) | |||
return apply(op, inp)[0] | |||
@@ -134,7 +136,7 @@ def reduce_sum( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.REDUCE_SUM | |||
mode = CollectiveComm.Mode.REDUCE_SUM | |||
return collective_comm(inp, mode, group, device) | |||
@@ -148,7 +150,7 @@ def broadcast( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.BROADCAST | |||
mode = CollectiveComm.Mode.BROADCAST | |||
return collective_comm(inp, mode, group, device) | |||
@@ -162,7 +164,7 @@ def all_gather( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.ALL_GATHER | |||
mode = CollectiveComm.Mode.ALL_GATHER | |||
return collective_comm(inp, mode, group, device) | |||
@@ -176,7 +178,7 @@ def reduce_scatter_sum( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.REDUCE_SCATTER_SUM | |||
mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM | |||
return collective_comm(inp, mode, group, device) | |||
@@ -190,7 +192,7 @@ def all_reduce_sum( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.ALL_REDUCE_SUM | |||
mode = CollectiveComm.Mode.ALL_REDUCE_SUM | |||
return collective_comm(inp, mode, group, device) | |||
@@ -204,7 +206,7 @@ def all_reduce_max( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.ALL_REDUCE_MAX | |||
mode = CollectiveComm.Mode.ALL_REDUCE_MAX | |||
return collective_comm(inp, mode, group, device) | |||
@@ -218,7 +220,7 @@ def all_reduce_min( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.ALL_REDUCE_MIN | |||
mode = CollectiveComm.Mode.ALL_REDUCE_MIN | |||
return collective_comm(inp, mode, group, device) | |||
@@ -232,7 +234,7 @@ def gather( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.GATHER | |||
mode = CollectiveComm.Mode.GATHER | |||
return collective_comm(inp, mode, group, device) | |||
@@ -246,7 +248,7 @@ def scatter( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.SCATTER | |||
mode = CollectiveComm.Mode.SCATTER | |||
return collective_comm(inp, mode, group, device) | |||
@@ -260,7 +262,7 @@ def all_to_all( | |||
:param group: communication group. | |||
:param device: execution device. | |||
""" | |||
mode = CollectiveCommMode.ALL_TO_ALL | |||
mode = CollectiveComm.Mode.ALL_TO_ALL | |||
return collective_comm(inp, mode, group, device) | |||
@@ -73,27 +73,7 @@ __all__ = [ | |||
] | |||
class _ElemwiseMode(Elemwise.Mode): | |||
@classmethod | |||
def __normalize(cls, val): | |||
if isinstance(val, str): | |||
if not hasattr(cls, "__member_upper_dict__"): | |||
cls.__member_upper_dict__ = { | |||
k.upper(): v for k, v in cls.__members__.items() | |||
} | |||
val = cls.__member_upper_dict__.get(val.upper(), val) | |||
return val | |||
@classmethod | |||
def convert(cls, val): | |||
val = cls.__normalize(val) | |||
if isinstance(val, cls): | |||
return val | |||
return cls(val) | |||
def _elwise(*args, mode): | |||
mode = _ElemwiseMode.convert(mode) | |||
op = builtin.Elemwise(mode) | |||
tensor_args = list( | |||
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | |||
@@ -13,7 +13,6 @@ import numbers | |||
from typing import Optional, Sequence, Tuple, Union | |||
from ..core.ops import builtin | |||
from ..core.ops._internal import param_defs as P | |||
from ..core.ops.special import Const | |||
from ..core.tensor import utils | |||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
@@ -601,9 +600,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: | |||
""" | |||
assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||
if descending: | |||
order = P.Argsort.Order.DESCENDING | |||
order = "DESCENDING" | |||
else: | |||
order = P.Argsort.Order.ASCENDING | |||
order = "ASCENDING" | |||
op = builtin.Argsort(order=order) | |||
if len(inp.shape) == 1: | |||
@@ -643,9 +642,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: | |||
""" | |||
assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||
if descending: | |||
order = P.Argsort.Order.DESCENDING | |||
order = "DESCENDING" | |||
else: | |||
order = P.Argsort.Order.ASCENDING | |||
order = "ASCENDING" | |||
op = builtin.Argsort(order=order) | |||
if len(inp.shape) == 1: | |||
@@ -695,13 +694,12 @@ def topk( | |||
if descending: | |||
inp = -inp | |||
Mode = P.TopK.Mode | |||
if kth_only: | |||
mode = Mode.KTH_ONLY | |||
mode = "KTH_ONLY" | |||
elif no_sort: | |||
mode = Mode.VALUE_IDX_NOSORT | |||
mode = "VALUE_IDX_NOSORT" | |||
else: | |||
mode = Mode.VALUE_IDX_SORTED | |||
mode = "VALUE_IDX_SORTED" | |||
op = builtin.TopK(mode=mode) | |||
if not isinstance(k, (TensorBase, TensorWrapperBase)): | |||
@@ -12,7 +12,6 @@ from typing import Optional, Sequence, Tuple, Union | |||
from ..core._imperative_rt import CompNode | |||
from ..core._trace_option import use_symbolic_shape | |||
from ..core.ops import builtin | |||
from ..core.ops._internal import param_defs as P | |||
from ..core.ops.builtin import BatchNorm | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph, utils | |||
@@ -121,11 +120,11 @@ def conv2d( | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
:type conv_mode: string or :class:`P.Convolution.Mode` | |||
:type conv_mode: string or :class:`Convolution.Mode` | |||
:param conv_mode: supports "CROSS_CORRELATION". Default: | |||
"CROSS_CORRELATION" | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode` | |||
:class:`Convolution.ComputeMode` | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
"Float32" would be used for accumulator and intermediate result, but only | |||
@@ -139,8 +138,8 @@ def conv2d( | |||
pad_h, pad_w = expand_hw(padding) | |||
dilate_h, dilate_w = expand_hw(dilation) | |||
Sparse = P.Convolution.Sparse | |||
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP | |||
Sparse = builtin.Convolution.Sparse | |||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
op = builtin.Convolution( | |||
stride_h=stride_h, | |||
stride_w=stride_w, | |||
@@ -187,11 +186,11 @@ def conv_transpose2d( | |||
``in_channels`` and ``out_channels`` must be divisible by groups, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. Default: 1 | |||
:type conv_mode: string or :class:`P.Convolution.Mode` | |||
:type conv_mode: string or :class:`Convolution.Mode` | |||
:param conv_mode: supports "CROSS_CORRELATION". Default: | |||
"CROSS_CORRELATION" | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode` | |||
:class:`Convolution.ComputeMode` | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
"Float32" would be used for accumulator and intermediate result, but only | |||
@@ -240,8 +239,6 @@ def local_conv2d( | |||
pad_h, pad_w = expand_hw(padding) | |||
dilate_h, dilate_w = expand_hw(dilation) | |||
Sparse = P.Convolution.Sparse | |||
op = builtin.GroupLocal( | |||
stride_h=stride_h, | |||
stride_w=stride_w, | |||
@@ -251,7 +248,7 @@ def local_conv2d( | |||
dilate_w=dilate_w, | |||
mode=conv_mode, | |||
compute_mode="DEFAULT", | |||
sparse=Sparse.DENSE, | |||
sparse="DENSE", | |||
) | |||
inp, weight = utils.convert_inputs(inp, weight) | |||
(output,) = apply(op, inp, weight) | |||
@@ -696,19 +693,14 @@ def batch_norm( | |||
if not training: | |||
op = builtin.BatchNorm( | |||
BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0 | |||
fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11" | |||
) | |||
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] | |||
return ret | |||
else: | |||
op = builtin.BatchNorm( | |||
BatchNorm.ParamDim.DIM_1C11, | |||
BatchNorm.FwdMode.TRAINING, | |||
eps, | |||
1.0 - momentum, | |||
1.0, | |||
0.0, | |||
avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11" | |||
) | |||
if has_mean or has_var: | |||
running_mean = make_full_if_none(running_mean, 0) | |||
@@ -1638,8 +1630,7 @@ def conv1d( | |||
pad_h = padding | |||
dilate_h = dilation | |||
Sparse = P.Convolution.Sparse | |||
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP | |||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
op = builtin.Convolution( | |||
stride_h=stride_h, | |||
stride_w=1, | |||
@@ -41,12 +41,12 @@ def conv_bias_activation( | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
:type conv_mode: string or :class:`P.Convolution.Mode`. | |||
:type conv_mode: string or :class:`Convolution.Mode`. | |||
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION' | |||
:param dtype: support for ``np.dtype``, Default: np.int8 | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode`. | |||
:class:`Convolution.ComputeMode`. | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. | |||
@@ -56,7 +56,7 @@ def conv_bias_activation( | |||
sh, sw = _pair_nonzero(stride) | |||
dh, dw = _pair_nonzero(dilation) | |||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
op = builtin.ConvBiasForward( | |||
op = builtin.ConvBias( | |||
stride_h=sh, | |||
stride_w=sw, | |||
pad_h=ph, | |||
@@ -101,12 +101,12 @@ def batch_conv_bias_activation( | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
:type conv_mode: string or :class:`P.Convolution.Mode`. | |||
:type conv_mode: string or :class:`Convolution.Mode`. | |||
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION' | |||
:param dtype: support for ``np.dtype``, Default: np.int8 | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode`. | |||
:class:`Convolution.ComputeMode`. | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. | |||
@@ -116,7 +116,7 @@ def batch_conv_bias_activation( | |||
sh, sw = _pair_nonzero(stride) | |||
dh, dw = _pair_nonzero(dilation) | |||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
op = builtin.BatchConvBiasForward( | |||
op = builtin.BatchConvBias( | |||
stride_h=sh, | |||
stride_w=sw, | |||
pad_h=ph, | |||
@@ -16,7 +16,6 @@ import numpy as np | |||
from ..core._imperative_rt import CompNode | |||
from ..core._wrap import device as as_device | |||
from ..core.ops import builtin | |||
from ..core.ops._internal import param_defs as P | |||
from ..core.ops.special import Const | |||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis | |||
@@ -722,7 +721,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
[1 0]] | |||
""" | |||
return inp.transpose(pattern) | |||
return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern)) | |||
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
@@ -756,10 +755,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
return inp.reshape(target_shape) | |||
AxisAddRemove = builtin.AxisAddRemove | |||
AxisDesc = AxisAddRemove.AxisDesc | |||
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: | |||
r""" | |||
Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``. | |||
@@ -826,7 +821,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
(1, 2) | |||
""" | |||
Param = builtin.AxisAddRemove.Param | |||
def get_axes(): | |||
try: | |||
@@ -839,8 +833,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
ndim = inp.ndim + len(axis) | |||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis)) | |||
op = builtin.AxisAddRemove(param=param) | |||
op = builtin.AddAxis(axis=axis) | |||
(result,) = apply(op, inp) | |||
return result | |||
@@ -21,9 +21,10 @@ import numpy as np | |||
from ..core._imperative_rt import GraphProfiler | |||
from ..core._imperative_rt.ops import ( | |||
CollectiveComm, | |||
OprAttr, | |||
GaussianRNG, | |||
RemoteRecv, | |||
RemoteSend, | |||
UniformRNG, | |||
VirtualDep, | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
@@ -182,14 +183,7 @@ class trace: | |||
record = self._seq[self._pc] | |||
op_, ihandles, ohandles = record | |||
if op != op_: | |||
# FIXME: will be removed once better rng implementation is done | |||
if isinstance(op, OprAttr) and ( | |||
op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type | |||
): | |||
if op.param[8:] != op_.param[8:]: | |||
raise TraceMismatchError("op different from last time") | |||
else: | |||
raise TraceMismatchError("op different from last time") | |||
raise TraceMismatchError("op different from last time") | |||
if len(ihandles) != len(args): | |||
raise TraceMismatchError("op input size different from last time") | |||
@@ -10,7 +10,6 @@ from typing import Tuple, Union | |||
import numpy as np | |||
from ..core.ops._internal import param_defs as P | |||
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | |||
from ..functional.types import _pair, _pair_nonzero | |||
from ..tensor import Parameter | |||
@@ -156,8 +155,6 @@ class Conv1d(_ConvNd): | |||
(2, 1, 2) | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
_compute_mode_type = P.Convolution.ComputeMode | |||
def __init__( | |||
self, | |||
@@ -176,8 +173,8 @@ class Conv1d(_ConvNd): | |||
stride = stride | |||
padding = padding | |||
dilation = dilation | |||
self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
self.conv_mode = conv_mode | |||
self.compute_mode = compute_mode | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
@@ -302,9 +299,6 @@ class Conv2d(_ConvNd): | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
_compute_mode_type = P.Convolution.ComputeMode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
@@ -322,8 +316,8 @@ class Conv2d(_ConvNd): | |||
stride = _pair_nonzero(stride) | |||
padding = _pair(padding) | |||
dilation = _pair_nonzero(dilation) | |||
self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
self.conv_mode = conv_mode | |||
self.compute_mode = compute_mode | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
@@ -414,9 +408,6 @@ class ConvTranspose2d(_ConvNd): | |||
effective when input and output are of float16 dtype. | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
_compute_mode_type = P.Convolution.ComputeMode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
@@ -434,8 +425,8 @@ class ConvTranspose2d(_ConvNd): | |||
stride = _pair_nonzero(stride) | |||
padding = _pair(padding) | |||
dilation = _pair_nonzero(dilation) | |||
self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
self.conv_mode = conv_mode | |||
self.compute_mode = compute_mode | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
@@ -509,8 +500,6 @@ class LocalConv2d(Conv2d): | |||
in_channels // groups, *kernel_size, out_channels // groups)`. | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
@@ -5,7 +5,6 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core.ops._internal import param_defs as P | |||
from ..functional.elemwise import _elwise | |||
from ..tensor import Tensor | |||
from .module import Module | |||
@@ -41,8 +41,8 @@ class Conv2d(Float.Conv2d, QATModule): | |||
float_module.dilation, | |||
float_module.groups, | |||
float_module.bias is not None, | |||
float_module.conv_mode.name, | |||
float_module.compute_mode.name, | |||
float_module.conv_mode, | |||
float_module.compute_mode, | |||
) | |||
qat_module.weight = float_module.weight | |||
qat_module.bias = float_module.bias | |||
@@ -5,7 +5,6 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ...core.ops._internal import param_defs as P | |||
from ...functional.elemwise import _elemwise_multi_type | |||
from ...tensor import Tensor | |||
from ..qat import elemwise as QAT | |||
@@ -15,11 +14,9 @@ from .module import QuantizedModule | |||
class Elemwise(QuantizedModule): | |||
r"""Quantized version of :class:`~.qat.elemwise.Elemwise`.""" | |||
_elemwise_multi_type_mode = P.ElemwiseMultiType.Mode | |||
def __init__(self, method, dtype=None): | |||
super().__init__() | |||
self.method = self._elemwise_multi_type_mode.convert("Q" + method) | |||
self.method = "Q" + method | |||
self.output_dtype = dtype | |||
def forward(self, *inps): | |||
@@ -15,7 +15,7 @@ from typing import Iterable, List, Optional | |||
from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry | |||
from ..core._imperative_rt import ProfilerImpl as _Profiler | |||
from ..core._imperative_rt.imperative import sync | |||
from ..core._imperative_rt.ops import CollectiveCommMode | |||
from ..core._imperative_rt.ops import CollectiveComm | |||
def _make_dict(**kwargs): | |||
@@ -194,7 +194,7 @@ class Profiler: | |||
_type_map = { | |||
OperatorNodeConfig: lambda x: _print_opnode_config(x), | |||
bytes: lambda x: base64.encodebytes(x).decode("ascii"), | |||
CollectiveCommMode: lambda x: str(x), | |||
CollectiveComm.Mode: lambda x: str(x), | |||
} | |||
_dumper_map = { | |||
@@ -421,9 +421,7 @@ void init_graph_rt(py::module m) { | |||
common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) { | |||
cg::VarNodeArray vinputs(inputs.begin(), inputs.end()); | |||
auto opr = OpDef::apply_on_var_node(def, vinputs); | |||
auto outputs = opr->usable_output(); | |||
return to_tuple(outputs); | |||
return to_tuple(OpDef::apply_on_var_node(def, vinputs)); | |||
}, | |||
py::arg(), py::arg(), py::arg("graph") = py::none()); | |||
@@ -109,9 +109,6 @@ void init_imperative_rt(py::module m) { | |||
py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef") | |||
.def("ctype", [](const OpDef& opdef) { | |||
if (auto attr = opdef.try_cast_final<OprAttr>()) { | |||
return attr->type.c_str(); | |||
} | |||
return opdef.dyn_typeinfo()->name; | |||
}) | |||
.def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { | |||
@@ -14,41 +14,29 @@ | |||
#include "megbrain/imperative.h" | |||
#include "megbrain/imperative/ops/backward_graph.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/ops/tensor_manip.h" | |||
#include "megbrain/imperative/ops/collective_comm.h" | |||
#include "megbrain/imperative/ops/io_remote.h" | |||
#include "megbrain/imperative/ops/cond_take.h" | |||
#include "megbrain/imperative/ops/nms.h" | |||
#include "megbrain/imperative/ops/elemwise.h" | |||
#include "megbrain/imperative/ops/batch_norm.h" | |||
#include "megbrain/imperative/ops/broadcast.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace py = pybind11; | |||
namespace { | |||
auto normalize_enum(const std::string& in) { | |||
std::string ret; | |||
for (auto&& c : in) { | |||
ret += toupper(c); | |||
} | |||
return ret; | |||
} | |||
} // anonymous namespace | |||
void init_ops(py::module m) { | |||
using namespace mgb::imperative; | |||
py::class_<OprAttr, std::shared_ptr<OprAttr>, OpDef>(m, "OprAttr") | |||
.def(py::init<>()) | |||
.def_readwrite("type", &OprAttr::type) | |||
.def_readwrite("param", &OprAttr::param) | |||
.def_readwrite("config", &OprAttr::config) | |||
.def_property("param", | |||
[](const OprAttr& attr) -> py::bytes { | |||
return std::string(attr.param.begin(), attr.param.end()); | |||
}, | |||
[] (OprAttr& attr, py::bytes data) { | |||
auto s = py::cast<std::string>(data); | |||
attr.param.clear(); | |||
attr.param.insert(attr.param.end(), s.begin(), s.end()); | |||
}); | |||
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||
const mgb::SmallVector<py::object>& inputs) { | |||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs)); | |||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
}; | |||
auto c = [pyc](const TensorPtr& tensor) { | |||
return pyc(tensor->dev_tensor()); | |||
@@ -56,162 +44,8 @@ void init_ops(py::module m) { | |||
return self.graph().interpret<py::object>(f, c, inputs); | |||
}); | |||
py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape") | |||
.def(py::init()); | |||
#define V(m) .value(#m, CollectiveComm::Mode::m) | |||
py::enum_<CollectiveComm::Mode>(m, "CollectiveCommMode") | |||
V(REDUCE_SUM) | |||
V(BROADCAST) | |||
V(ALL_GATHER) | |||
V(REDUCE_SCATTER_SUM) | |||
V(ALL_REDUCE_SUM) | |||
V(ALL_REDUCE_MAX) | |||
V(ALL_REDUCE_MIN) | |||
V(ALL_REDUCE_PROD) | |||
V(GATHER) | |||
V(SCATTER) | |||
V(ALL_TO_ALL); | |||
#undef V | |||
py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef>(m, "CollectiveComm") | |||
.def(py::init<>()) | |||
.def_readwrite("key", &CollectiveComm::key) | |||
.def_readwrite("nr_devices", &CollectiveComm::nr_devices) | |||
.def_readwrite("rank", &CollectiveComm::rank) | |||
.def_readwrite("is_root", &CollectiveComm::is_root) | |||
.def_readwrite("local_grad", &CollectiveComm::local_grad) | |||
.def_readwrite("addr", &CollectiveComm::addr) | |||
.def_readwrite("port", &CollectiveComm::port) | |||
.def_readwrite("mode", &CollectiveComm::mode) | |||
.def_readwrite("dtype", &CollectiveComm::dtype) | |||
.def_readwrite("backend", &CollectiveComm::backend) | |||
.def_readwrite("comp_node", &CollectiveComm::comp_node); | |||
py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef>(m, "RemoteSend") | |||
.def(py::init<>()) | |||
.def_readwrite("key", &RemoteSend::key) | |||
.def_readwrite("addr", &RemoteSend::addr) | |||
.def_readwrite("port", &RemoteSend::port) | |||
.def_readwrite("rank_to", &RemoteSend::rank_to); | |||
py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef>(m, "RemoteRecv") | |||
.def(py::init<>()) | |||
.def_readwrite("key", &RemoteRecv::key) | |||
.def_readwrite("addr", &RemoteRecv::addr) | |||
.def_readwrite("port", &RemoteRecv::port) | |||
.def_readwrite("rank_from", &RemoteRecv::rank_from) | |||
.def_readwrite("shape", &RemoteRecv::shape) | |||
.def_readwrite("cn", &RemoteRecv::cn) | |||
.def_readwrite("dtype", &RemoteRecv::dtype); | |||
py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef>(m, "ParamPackSplit") | |||
.def(py::init<>()) | |||
.def_readwrite("offsets", &ParamPackSplit::offsets) | |||
.def_readwrite("shapes", &ParamPackSplit::shapes); | |||
py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef>(m, "ParamPackConcat") | |||
.def(py::init<>()) | |||
.def_readwrite("offsets", &ParamPackConcat::offsets); | |||
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") | |||
.def(py::init<>()); | |||
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") | |||
.def(py::init<>()); | |||
py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef>(m, "NMSKeep") | |||
.def(py::init<float, uint32_t>()) | |||
.def_readwrite("iou_thresh", &NMSKeep::iou_thresh) | |||
.def_readwrite("max_output", &NMSKeep::max_output); | |||
py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise"); | |||
elemwise.def(py::init<Elemwise::Mode>()) | |||
.def_readwrite("mode", &Elemwise::mode); | |||
#define V(m) .value(#m, Elemwise::Mode::m) | |||
py::enum_<Elemwise::Mode>(elemwise, "Mode") | |||
V(RELU) | |||
V(ABS) | |||
V(ACOS) | |||
V(ASIN) | |||
V(CEIL) | |||
V(COS) | |||
V(EXP) | |||
V(EXPM1) | |||
V(FLOOR) | |||
V(LOG) | |||
V(LOG1P) | |||
V(NEGATE) | |||
V(SIGMOID) | |||
V(SIN) | |||
V(TANH) | |||
V(ABS_GRAD) | |||
V(ADD) | |||
V(FLOOR_DIV) | |||
V(MAX) | |||
V(MIN) | |||
V(MOD) | |||
V(MUL) | |||
V(POW) | |||
V(SIGMOID_GRAD) | |||
V(SUB) | |||
V(SWITCH_GT0) | |||
V(TANH_GRAD) | |||
V(TRUE_DIV) | |||
V(LOG_SUM_EXP) | |||
V(LT) | |||
V(LEQ) | |||
V(EQ) | |||
V(SHL) | |||
V(SHR) | |||
V(COND_LEQ_MOV) | |||
V(FUSE_MUL_ADD3) | |||
V(FUSE_MUL_ADD4) | |||
V(FUSE_ADD_RELU) | |||
V(FUSE_ADD_SIGMOID) | |||
V(FUSE_ADD_TANH) | |||
V(FAST_TANH) | |||
V(FAST_TANH_GRAD) | |||
V(ROUND) | |||
V(RMULH) | |||
V(ATAN2) | |||
V(ERF) | |||
V(ERFINV) | |||
V(ERFC) | |||
V(ERFCINV) | |||
V(H_SWISH) | |||
V(H_SWISH_GRAD) | |||
V(FUSE_ADD_H_SWISH) | |||
V(NOT) | |||
V(AND) | |||
V(OR) | |||
V(XOR); | |||
#undef V | |||
py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm"); | |||
batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&, double, double, float, float>()) | |||
.def_readwrite("param_dim", &BatchNorm::param_dim) | |||
.def_readwrite("fwd_mode", &BatchNorm::fwd_mode) | |||
.def_readwrite("epsilon", &BatchNorm::epsilon) | |||
.def_readwrite("avg_factor", &BatchNorm::avg_factor) | |||
.def_readwrite("scale", &BatchNorm::scale) | |||
.def_readwrite("bias", &BatchNorm::bias); | |||
#define V(m) .value(#m, BatchNorm::Param::ParamDim::m) | |||
py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim") | |||
V(DIM_11HW) | |||
V(DIM_1CHW) | |||
V(DIM_1C11); | |||
#undef V | |||
#define V(m) .value(#m, BatchNorm::Param::FwdMode::m) | |||
py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode") | |||
V(TRAINING) | |||
V(INFERENCE); | |||
#undef V | |||
py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast") | |||
.def(py::init<>()); | |||
#include "opdef.py.inl" | |||
} |
@@ -113,7 +113,7 @@ def test_quint8_typecvt(): | |||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
def typecvt(x, dt=None): | |||
(y,) = apply(ops.TypeCvt(param=dt), x) | |||
(y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
return y | |||
# convert to quint8 | |||
@@ -194,7 +194,7 @@ def test_quint4_typecvt(): | |||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
def typecvt(x, dt=None): | |||
(y,) = apply(ops.TypeCvt(param=dt), x) | |||
(y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
return y | |||
# convert to quint4 | |||
@@ -11,10 +11,9 @@ import collections | |||
import numpy as np | |||
import pytest | |||
import megengine.core.ops.builtin | |||
import megengine.core.tensor.raw_tensor | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.ops._internal import all_ops | |||
from megengine.core.ops import builtin | |||
from megengine.core.tensor import Tensor | |||
from megengine.core.tensor.core import apply | |||
from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor | |||
@@ -105,7 +104,7 @@ def canonize_inputs(inputs, *, config): | |||
need_cvt = False | |||
for i in old_inputs: | |||
if isinstance(i, RawTensor): | |||
get_comp_node = lambda cn=i.device.to_c(): cn | |||
get_comp_node = lambda cn=i.device: cn | |||
else: | |||
need_cvt = True | |||
inputs.append(i) | |||
@@ -193,91 +192,91 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
def transpose(*args, **kwargs): | |||
op = all_ops.Dimshuffle(**kwargs).to_c() | |||
op = builtin.Dimshuffle(**kwargs) | |||
return invoke_op(op, args) | |||
def broadcast(input, tshape): | |||
op = all_ops.Broadcast().to_c() | |||
op = builtin.Broadcast() | |||
return invoke_op(op, (input, tshape), canonize_reshape) | |||
def subtensor(input, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.Subtensor(items).to_c() | |||
op = builtin.Subtensor(items) | |||
return invoke_op(op, (input, *tensors)) | |||
def set_subtensor(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.SetSubtensor(items).to_c() | |||
op = builtin.SetSubtensor(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def incr_subtensor(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.IncrSubtensor(items).to_c() | |||
op = builtin.IncrSubtensor(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def advance_indexing(input, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.IndexingMultiAxisVec(items).to_c() | |||
op = builtin.IndexingMultiAxisVec(items) | |||
return invoke_op(op, (input, *tensors)) | |||
def set_advance_indexing(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.IndexingSetMultiAxisVec(items).to_c() | |||
op = builtin.IndexingSetMultiAxisVec(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def incr_advance_indexing(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.IndexingIncrMultiAxisVec(items).to_c() | |||
op = builtin.IndexingIncrMultiAxisVec(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def mesh_indexing(input, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.MeshIndexing(items).to_c() | |||
op = builtin.MeshIndexing(items) | |||
return invoke_op(op, (input, *tensors)) | |||
def set_mesh_indexing(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.SetMeshIndexing(items).to_c() | |||
op = builtin.SetMeshIndexing(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def incr_mesh_indexing(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.IncrMeshIndexing(items).to_c() | |||
op = builtin.IncrMeshIndexing(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def batched_mesh_indexing(input, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.BatchedMeshIndexing(items).to_c() | |||
op = builtin.BatchedMeshIndexing(items) | |||
return invoke_op(op, (input, *tensors)) | |||
def batched_set_mesh_indexing(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.BatchedSetMeshIndexing(items).to_c() | |||
op = builtin.BatchedSetMeshIndexing(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def batched_incr_mesh_indexing(input, value, tuple_val): | |||
input, tensors, items = unpack_getitem(input, tuple_val) | |||
op = all_ops.BatchedIncrMeshIndexing(items).to_c() | |||
op = builtin.BatchedIncrMeshIndexing(items) | |||
return invoke_op(op, (input, value, *tensors)) | |||
def test_transpose(): | |||
x = np.arange(10).reshape(2, 5).astype("int32") | |||
xx = as_raw_tensor(x) | |||
(yy,) = transpose(xx, pattern="1x0") | |||
(yy,) = transpose(xx, pattern=[1, -1, 0]) | |||
np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) | |||
@@ -1,320 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from io import StringIO | |||
import re | |||
import argparse | |||
import subprocess | |||
import os | |||
import textwrap | |||
import inspect | |||
def camel2underscore( | |||
name, *, | |||
first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'), | |||
all_cap_re = re.compile('([a-z])([A-Z]+)')): | |||
if name.isupper(): | |||
return name.lower() | |||
s1 = first_cap_re.sub(r'\1_\2', name) | |||
return all_cap_re.sub(r'\1_\2', s1).lower() | |||
def caller_lineno(level=1): | |||
f = inspect.stack()[level+1] | |||
return '%s:%d' % (f.filename, f.lineno) | |||
class Doc: | |||
"""wrap an identifier and doc""" | |||
_id = None | |||
def __init__(self, id_, doc, typestr=None, default=None): | |||
self._id = id_ | |||
self.doc = doc | |||
self.typestr = typestr | |||
self.default = default | |||
def __str__(self): | |||
return self._id | |||
class Context: | |||
fout = None | |||
def __init__(self): | |||
self.fout = StringIO() | |||
self.indent = 0 | |||
self.skipped = [] | |||
self.generated_signature = set() | |||
self.generated_opr = dict() | |||
def write(self, text, *fmt, indent=0): | |||
text = textwrap.dedent(text) | |||
text = textwrap.indent(text, ' '*4*(self.indent + indent)) | |||
text = text % fmt | |||
if not text.endswith('\n'): | |||
text += '\n' | |||
self.fout.write(text) | |||
def _gen_signature(self, params, *, have_config=True, | |||
has_out_dtype=False): | |||
sig = ['self', '*'] | |||
for i, _ in params: | |||
sig.append('{}=None'.format(i)) | |||
if have_config: | |||
sig.extend(['name=None', 'comp_node=None', 'config=None']) | |||
if has_out_dtype: | |||
sig.append('dtype=None') | |||
if params: | |||
sig.append('**kwargs') | |||
if sig[-1] == '*': | |||
sig.pop() | |||
return ', '.join(sig) | |||
def _write_canonize_inputs(self, inputs, convert_inputs, | |||
convert_inputs_args=None, | |||
has_out_dtype=False): | |||
self._write_gen_config(has_out_dtype) | |||
inputs = list(map(str, inputs)) | |||
if convert_inputs_args is None: | |||
if inputs[0][0] == '*': | |||
arg = inputs[0][1:] | |||
else: | |||
arg = '[{}]'.format(', '.join(inputs)) | |||
else: | |||
arg = convert_inputs_args | |||
self.write('inputs = helper.%s(%s, config=config)', | |||
convert_inputs, arg) | |||
def _write_gen_config(self, has_out_dtype=False): | |||
self.write('''\ | |||
config = config or Config() | |||
if name: | |||
config.name = name | |||
if comp_node: | |||
config.comp_node = comp_node | |||
''') | |||
if has_out_dtype: | |||
self.write('''\ | |||
if dtype: | |||
config.dtype = dtype | |||
''') | |||
self.write('self.config = config') | |||
def _write_make_params(self, params): | |||
for pname, ptype in params: | |||
self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)', | |||
pname, pname, ptype) | |||
self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)') | |||
def _write_doc(self, inputs, params, desc): | |||
self.write('"""') | |||
if isinstance(desc, Doc): | |||
assert desc._id is None | |||
self.write(desc.doc) | |||
elif desc: | |||
for i in textwrap.wrap(desc, 75): | |||
self.write(i) | |||
self.write('') | |||
for i in inputs: | |||
name = str(i) | |||
typestr = ':class:`.Tensor`' | |||
if name[0] == '*': | |||
name = name[1:] | |||
typestr = 'list of ' + typestr | |||
if isinstance(i, Doc): | |||
self.write(':param %s: %s', name, i.doc) | |||
if i.typestr is not None: | |||
typestr = i.typestr | |||
if typestr: | |||
if not isinstance(i, Doc): | |||
self.write(':param %s: ', name) | |||
self.write(':type %s: %s', name, typestr) | |||
for pname, ptype in params: | |||
self.write(':param %s: ', pname) | |||
self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`', | |||
pname, ptype) | |||
self.write(':param comp_node: see doc for *config*') | |||
self.write(':param name: see doc for *config*') | |||
self.write( | |||
':param config: give a :class:`.OperatorNodeConfig` object to set ' | |||
'operator name and comp node. This can also be achieved by passing ' | |||
'*comp_node* and *name* separately.') | |||
self.write('"""') | |||
def _write_return(self, name, outputs): | |||
self.write('opdef = helper.PodOpVisitor("%s", config, params)', name) | |||
self.write('outputs = helper.create_op(opdef, inputs)') | |||
if outputs: | |||
self.write('outputs = [outputs[i] for i in %s]', | |||
list(map(int, outputs))) | |||
self.write('return helper.convert_outputs(outputs)') | |||
def decl_opr(self, name, *, inputs, params, desc=None, pyname=None,
             canonize_input_vars=None,
             canonize_input_vars_args=None, body=None,
             outputs=None, version=0, has_out_dtype=False):
    """Declare an operator and queue it for code generation.

    :param inputs: name of variable inputs; a name starting with `*' means
        a list of vars
    :type inputs: list of str
    :param params: (param name, param type) pairs; it can be a single
        string representing the param type, and param name defaults to
        'param'
    :type params: list of pair of str, or str
    :param pyname: python function name
    :param body: extra statements to be placed before calling _create_opr;
        oprs with a custom body cannot be generated mechanically and are
        recorded in ``self.skipped`` instead
    :param outputs: the indices of output vars to be selected from raw opr
        result

    .. note:: *pyname*, *canonize_input_vars* and *canonize_input_vars_args*
        are accepted for signature compatibility but are not used by this
        generator.
    """
    # Record holder for a declared opr awaiting generation.
    class OprItem:
        def __init__(self, inputs, desc, params, version, has_out_dtype):
            self.inputs = inputs
            self.desc = desc
            self.params = params
            self.version = version
            self.has_out_dtype = has_out_dtype

    # Oprs with hand-written bodies cannot be generated mechanically.
    if body:
        self.skipped.append(name)
        return
    # NOTE: the original `body = body or []` that followed was dead code --
    # `body` is always falsy here and never read again -- so it is removed.

    # Deduplicate identical declarations (same name/params/dtype/version).
    signature = (name,
                 params if isinstance(params, str) else frozenset(params),
                 has_out_dtype, version)
    if signature in self.generated_signature:
        self.skipped.append(name)
        return
    self.generated_signature.add(signature)

    # Normalize the single-string shorthand to a list of (name, type) pairs.
    if isinstance(params, str):
        params = [('param', params)]
    assert params

    if name in self.generated_opr:
        org_opr = self.generated_opr[name]
        # Only a strictly newer version may replace an existing entry, and
        # it must keep the same doc and inputs as the one it replaces.
        if version > org_opr.version:
            def compare_doc(a, b):
                if isinstance(a, str):
                    return a == b
                assert isinstance(a, Doc)
                return a.doc == b.doc
            assert compare_doc(desc, org_opr.desc)
            assert len(inputs) == len(org_opr.inputs)
            for i, j in zip(inputs, org_opr.inputs):
                assert compare_doc(i, j)
            self.generated_opr[name] = OprItem(
                inputs, desc, params, version, has_out_dtype)
    else:
        self.generated_opr[name] = OprItem(
            inputs, desc, params, version, has_out_dtype)
def write_generated_oprs(self):
    """Emit a PodOpVisitor subclass for every opr collected by decl_opr."""
    for opr_name, item in self.generated_opr.items():
        self.write('# %s', caller_lineno())
        self.write('class %s(PodOpVisitor):', opr_name)
        self.indent += 1

        # Class-level metadata: param names and (possibly versioned) name.
        fields, _ = zip(*item.params)
        quoted = ', '.join('"{}"'.format(f) for f in fields)
        self.write('param_names = (%s,)', quoted)
        versioned = ('{}V{}'.format(opr_name, item.version)
                     if item.version else opr_name)
        self.write('name = "%s"', versioned)
        self.write('\n')

        # Constructor: signature, config setup, then param construction.
        self.write('def __init__(%s):',
                   self._gen_signature(item.params,
                                       has_out_dtype=item.has_out_dtype))
        self.indent += 1
        self._write_gen_config(has_out_dtype=item.has_out_dtype)
        self.write('\n')
        self._write_make_params(item.params)
        self.write('\n')
        self.indent -= 2
def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
                 desc=None, local_defs=[], have_config=True, params=None,
                 has_out_dtype=False):
    """Raw oprs are not auto-generated; only record the name as skipped.

    NOTE(review): the mutable list defaults are never mutated here, so they
    are harmless; all keyword arguments other than *name* are ignored.
    """
    self.skipped.append(name)
def get_str(self):
    """Return everything written to the output buffer so far."""
    return self.fout.getvalue()
def all_list(self):
    """Return a python-source list literal of all generated opr names,
    one quoted name per line, terminated by a newline."""
    entries = [' "%s",' % name for name in self.generated_opr]
    return '\n'.join(['['] + entries + [']']) + '\n'
def main():
    """Entry point: execute each decl file against a fresh Context, then
    render the collected oprs into ops.tpl.py and write the result.

    Reads ``inputs`` (decl files) and ``--output`` from the command line.
    """
    parser = argparse.ArgumentParser(
        description='generate operator function def code from decl file')
    parser.add_argument('inputs', nargs='+')
    parser.add_argument('--output', '-o')
    args = parser.parse_args()

    gen = Context()
    # Names visible to the executed decl files.
    exec_globals = {
        'decl_opr': gen.decl_opr,
        'decl_raw_opr': gen.decl_raw_opr,
        'Doc': Doc,
        'camel2underscore': camel2underscore,
    }
    for i in args.inputs:
        print('generate ops from {}'.format(i))
        with open(i) as fin:
            exec(compile(fin.read(), i, 'exec'), exec_globals)
    gen.write_generated_oprs()

    # Best effort: record the current git commit in the generated file.
    # Catch only the failures git can actually produce (non-zero exit or
    # missing executable) -- the original bare `except:` also swallowed
    # KeyboardInterrupt/SystemExit.
    try:
        git_commit = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'], universal_newlines=True,
            cwd=os.path.dirname(os.path.realpath(__file__))).strip()
    except (subprocess.CalledProcessError, OSError):
        git_commit = 'NOT_A_GIT_REPO'

    def relpath(*args):
        # Resolve a path relative to this script's directory.
        d = os.path.dirname(__file__)
        return os.path.join(d, *args)

    with open(relpath('ops.tpl.py')) as fin:
        with open(args.output, 'w') as fout:
            fout.write(fin.read()
                       .replace('{%all%}', gen.all_list())
                       .replace('{%body%}', gen.get_str())
                       .replace('{%git_commit%}', git_commit))

    print('Skipped:')
    print(*gen.skipped, sep='\n')

if __name__ == '__main__':
    main()
@@ -1,40 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""This python module contains functions to apply the operators defined by | |||
megbrain. | |||
.. note:: | |||
Most of the functions are automatically generated, and their signatures | |||
contain a ``param`` argument (or more than one argument, such as | |||
:func:`convolution`, which has ``param`` and ``execution_policy``) and also | |||
accept keyword arguments. In such case, it can be called by either | |||
providing a param object of appropriate type, or by passing the arguments | |||
needed by the constructor of param object to the keyword arguments. | |||
Furthermore, for a param that needs an enumeration member, the enum name | |||
can be used to refer to the enum object. | |||
For example, the following statements are equivalent:: | |||
elemwise([a, b], mode='max') | |||
elemwise([a, b], mode=opr_param_defs.Elemwise.Mode.MAX) | |||
elemwise([a, b], param=opr_param_defs.Elemwise('max')) | |||
""" | |||
__git_commit__ = "{%git_commit%}" | |||
import collections | |||
from . import helper | |||
from .helper import PodOpVisitor | |||
from . import param_defs | |||
from ..._imperative_rt import OperatorNodeConfig as Config | |||
__all__ = {%all%} | |||
{%body%} |
@@ -36,7 +36,7 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||
return def.trait()->apply_on_physical_tensor(def, inputs); | |||
} | |||
cg::OperatorNodeBase* OpDef::apply_on_var_node( | |||
VarNodeArray OpDef::apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
return def.trait()->apply_on_var_node(def, inputs); | |||
@@ -56,6 +56,14 @@ BackwardGraphResult OpDef::make_backward_graph( | |||
return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | |||
} | |||
size_t OpDef::hash() const { | |||
return trait()->hash(*this); | |||
} | |||
bool OpDef::is_same_st(const Hashable& rhs) const { | |||
return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs)); | |||
} | |||
const OpTrait* OpDef::trait() const { | |||
if (!m_trait) { | |||
m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo()); | |||
@@ -23,7 +23,7 @@ namespace detail { | |||
struct StaticData { | |||
std::list<OpTrait> registries; | |||
std::unordered_map<const char*, OpTrait*> name2reg; | |||
std::unordered_map<std::string, OpTrait*> name2reg; | |||
std::unordered_map<Typeinfo*, OpTrait*> type2reg; | |||
}; | |||
@@ -30,6 +30,32 @@ struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> { | |||
return this->Base::operator ()(args...); | |||
} | |||
}; | |||
template<typename T> | |||
struct ToVarNodeArray: std::false_type {}; | |||
template<> | |||
struct ToVarNodeArray<SymbolVar>: std::true_type { | |||
VarNodeArray operator()(const SymbolVar& inp) { | |||
return {inp.node()}; | |||
} | |||
}; | |||
template<> | |||
struct ToVarNodeArray<SymbolVarArray>: std::true_type { | |||
VarNodeArray operator()(const SymbolVarArray& inputs) { | |||
return cg::to_var_node_array(inputs); | |||
} | |||
}; | |||
template<size_t N> | |||
struct ToVarNodeArray<std::array<SymbolVar, N>>: std::true_type { | |||
VarNodeArray operator()(const std::array<SymbolVar, N>& inp) { | |||
return cg::to_var_node_array({inp.begin(), inp.end()}); | |||
} | |||
}; | |||
template<> | |||
struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { | |||
VarNodeArray operator()(const cg::OperatorNodeBase* opr) { | |||
return opr->usable_output(); | |||
} | |||
}; | |||
} // detail | |||
using OpDefMaker = detail::OpMeth< | |||
@@ -42,6 +68,8 @@ using InferOutputAttrsFallible = detail::OpMeth< | |||
decltype(OpDef::infer_output_attrs_fallible)>; | |||
using GradMaker = detail::OpMeth< | |||
decltype(OpDef::make_backward_graph)>; | |||
using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | |||
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
struct OpTrait { | |||
const char* name; | |||
@@ -50,6 +78,8 @@ struct OpTrait { | |||
ApplyOnVarNode apply_on_var_node; | |||
InferOutputAttrsFallible infer_output_attrs_fallible; | |||
GradMaker make_backward_graph; | |||
HashFunc hash; | |||
IsSame is_same_st; | |||
OpTrait(const char* name); | |||
static OpTrait* find_by_name(const char* name); | |||
static OpTrait* find_by_typeinfo(Typeinfo* type); | |||
@@ -61,7 +91,9 @@ struct OpTrait { | |||
cb(apply_on_physical_tensor) \ | |||
cb(apply_on_var_node) \ | |||
cb(infer_output_attrs_fallible) \ | |||
cb(make_backward_graph) | |||
cb(make_backward_graph) \ | |||
cb(hash) \ | |||
cb(is_same_st) | |||
struct OpTraitRegistry { | |||
OpTrait* trait; | |||
@@ -97,6 +129,15 @@ struct OpTraitRegistry { | |||
void do_insert(Typeinfo* type); | |||
static OpTraitRegistry do_insert(const char* name); | |||
template<typename T, | |||
typename To = detail::ToVarNodeArray<T>, | |||
typename = std::enable_if_t<To::value>> | |||
OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) { | |||
return apply_on_var_node([=](const OpDef& opdef, const VarNodeArray& inputs) { | |||
return To()(f(opdef, inputs)); | |||
}); | |||
} | |||
}; | |||
} // namespace imperative | |||
@@ -0,0 +1,46 @@ | |||
/** | |||
* \file imperative/src/impl/ops/autogen.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "../op_trait.h" | |||
using namespace megdnn; | |||
// FIXME: remove this when mgb::hash support tuple_hash | |||
namespace mgb { | |||
namespace { | |||
template<typename T, size_t ...Ns> | |||
auto tail(T t, std::index_sequence<Ns...>) { | |||
return std::make_tuple(std::get<Ns+1>(t)...); | |||
} | |||
} // anonymous namespace | |||
template<typename T, typename ...Args> | |||
class HashTrait<std::tuple<T, Args...>> { | |||
constexpr static size_t length = sizeof...(Args); | |||
public: | |||
static size_t eval(const std::tuple<T, Args...> &t) { | |||
const T& val = std::get<0>(t); | |||
if constexpr (!length) { | |||
return mgb::hash(val); | |||
} else { | |||
return mgb::hash_pair_combine(mgb::hash(val), | |||
mgb::hash(tail(t, std::make_index_sequence<length - 1>{}))); | |||
} | |||
} | |||
}; | |||
} // namespace mgb | |||
namespace mgb::imperative { | |||
#include "./opdef.cpp.inl" | |||
} // namespace mgb::imperative |
@@ -9,7 +9,8 @@ | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/batch_norm.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "../op_trait.h" | |||
namespace mgb { | |||
@@ -19,9 +20,7 @@ namespace { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = &node_->cast_final_safe<opr::BatchNorm>(); | |||
auto&& param = node->param(); | |||
return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon, | |||
param.avg_factor, param.scale, param.bias); | |||
return BatchNorm::make(node->param()); | |||
} | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
@@ -33,13 +32,11 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); | |||
if (nr_inp == 3) { | |||
return opr::BatchNorm::make( | |||
inputs[0], inputs[1], inputs[2], | |||
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] | |||
inputs[0], inputs[1], inputs[2], bn_opr.param())[0] | |||
.node()->owner_opr(); | |||
} else { | |||
return opr::BatchNorm::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], | |||
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0] | |||
.node()->owner_opr(); | |||
} | |||
} | |||
@@ -52,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
mgb_assert(nr_inp == 3 ||nr_inp == 5, | |||
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); | |||
// need running mean/variance | |||
bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING; | |||
bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING; | |||
size_t nr_out = need_stat? 5 : 3; | |||
SmallVector<LogicalTensorDesc> out_shapes(nr_out); | |||
auto&& i0 = inputs[0]; | |||
@@ -76,8 +73,6 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | |||
.fallback(); | |||
} // anonymous namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm); | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -9,7 +9,9 @@ | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/broadcast.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "../op_trait.h" | |||
namespace mgb { | |||
@@ -87,8 +89,6 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
.fallback(); | |||
} // anonymous namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast); | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -18,7 +18,7 @@ | |||
#include "megbrain/utils/hash.h" | |||
#endif // MGB_ENABLE_OPR_MM | |||
#include "megbrain/imperative/ops/collective_comm.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -61,8 +61,8 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) { | |||
auto [addr, port] = split_address(group_client->get_addr()); | |||
auto comp_node = node->config().get_single_comp_node().to_string_logical(); | |||
return std::make_shared<CollectiveComm>( | |||
comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(), | |||
comm.local_grad(), addr, std::stoi(port), comm.param().mode, | |||
comm.param().mode, comm.key(), comm.nr_devices(), comm.rank(), | |||
comm.is_root(), comm.local_grad(), addr, std::stoi(port), | |||
comm.dtype(), comm.backend(), comp_node); | |||
} | |||
@@ -73,35 +73,6 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) | |||
} // anonymous namespace | |||
#endif // MGB_ENABLE_OPR_MM | |||
bool CollectiveComm::is_same_st(const Hashable& another) const{ | |||
auto* comm_opr = another.try_cast_final<CollectiveComm>(); | |||
if(!comm_opr){ | |||
return false; | |||
} | |||
return as_tuple() == comm_opr->as_tuple(); | |||
} | |||
size_t CollectiveComm::hash() const{ | |||
XXHash xxhash{}; | |||
auto append = [&xxhash](auto field){ | |||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
}; | |||
append(key); | |||
append(nr_devices); | |||
append(rank); | |||
append(is_root); | |||
append(local_grad); | |||
append(addr); | |||
append(port); | |||
append(mode); | |||
append(backend); | |||
append(comp_node); | |||
return xxhash.digest(); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -9,8 +9,7 @@ | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/cond_take.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/misc.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
@@ -19,8 +18,6 @@ using namespace megdnn; | |||
namespace mgb::imperative { | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); | |||
namespace { | |||
class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { | |||
@@ -9,7 +9,9 @@ | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/elemwise.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "../op_trait.h" | |||
namespace mgb { | |||
@@ -33,7 +35,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
auto trait = Elemwise::ModeTrait::from_mode(op_def.mode); | |||
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | |||
mgb_assert(inputs.size() == trait.arity, | |||
"%s expects %u inputs; got %zu actually", trait.name, | |||
trait.arity, inputs.size()); | |||
@@ -70,8 +72,6 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
.fallback(); | |||
} // anonymous namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise); | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -18,7 +18,7 @@ | |||
#include "megbrain/opr/mm_handler.h" | |||
#endif // MGB_ENABLE_OPR_MM | |||
#include "megbrain/imperative/ops/io_remote.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -60,45 +60,5 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||
} // anonymous namespace | |||
#endif // MGB_ENABLE_OPR_MM | |||
bool RemoteSend::is_same_st(const Hashable& another) const{ | |||
return as_tuple() == another.cast_final<RemoteSend>().as_tuple(); | |||
} | |||
size_t RemoteSend::hash() const{ | |||
XXHash xxhash; | |||
auto append = [&xxhash](auto field){ | |||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
}; | |||
append(key); | |||
append(addr); | |||
append(port); | |||
append(rank_to); | |||
return xxhash.digest(); | |||
} | |||
bool RemoteRecv::is_same_st(const Hashable& another) const{ | |||
return as_tuple() == another.cast_final<RemoteRecv>().as_tuple(); | |||
} | |||
size_t RemoteRecv::hash() const{ | |||
XXHash xxhash; | |||
auto append = [&xxhash](auto field){ | |||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
}; | |||
append(key); | |||
append(addr); | |||
append(port); | |||
append(rank_from); | |||
append(cn.to_string()); | |||
append(dtype.handle()); | |||
append(shape.to_string()); | |||
return xxhash.digest(); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -11,7 +11,7 @@ | |||
#include "../op_trait.h" | |||
#include "megbrain/imperative/ops/nms.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/standalone/nms_opr.h" | |||
namespace mgb { | |||
@@ -37,8 +37,6 @@ OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) | |||
.fallback(); | |||
} // anonymous namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep); | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -0,0 +1,630 @@ | |||
/** | |||
* \file imperative/src/impl/ops/autogen.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// FIXME: split this file into separate files for each specialized op | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/imgproc.h" | |||
#include "megbrain/opr/indexing.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/misc.h" | |||
#include "megbrain/opr/nn_int.h" | |||
#include "megbrain/opr/rand.h" | |||
#include "megbrain/opr/tensor_gen.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace { namespace convolution { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = &node_->cast_final_safe<opr::Convolution>(); | |||
return Convolution::make(node->param(), node->execution_policy()); | |||
} | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const Convolution&>(def); | |||
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy()); | |||
} | |||
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // convolution | |||
namespace { namespace convolution_backward_data { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const ConvolutionBackwardData&>(def); | |||
cg::OperatorNodeConfig config; | |||
if (inputs.size() == 2) { | |||
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} else { | |||
mgb_assert(inputs.size() == 3); | |||
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
} | |||
} | |||
OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // convolution_backward_data | |||
namespace { namespace dimshuffle { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = &node_->cast_final_safe<opr::Dimshuffle>(); | |||
std::vector<int> pattern(node->param().pattern_len); | |||
for (size_t i = 0; i < node->param().pattern_len; ++ i) { | |||
pattern[i] = node->param().pattern[i]; | |||
} | |||
return Dimshuffle::make(pattern); | |||
} | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& ds = static_cast<const Dimshuffle&>(def); | |||
return opr::Dimshuffle::make(inputs[0], ds.pattern); | |||
} | |||
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // dimshuffle | |||
namespace { namespace add_axis { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& add_axis = static_cast<const AddAxis&>(def); | |||
using Desc = opr::AxisAddRemove::AxisDesc; | |||
std::vector<Desc> param; | |||
for (auto&& i : add_axis.axis) { | |||
param.push_back(Desc::make_add(i)); | |||
} | |||
return opr::AxisAddRemove::make(inputs[0], param); | |||
} | |||
OP_TRAIT_REG(AddAxis, AddAxis) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // add_axis | |||
namespace { namespace remove_axis { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& remove_axis = static_cast<const RemoveAxis&>(def); | |||
using Desc = opr::AxisAddRemove::AxisDesc; | |||
std::vector<Desc> param; | |||
for (auto&& i : remove_axis.axis) { | |||
param.push_back(Desc::make_remove(i)); | |||
} | |||
return opr::AxisAddRemove::make(inputs[0], param); | |||
} | |||
OP_TRAIT_REG(RemoveAxis, RemoveAxis) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // remove_axis | |||
namespace { namespace top_k { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& topk = static_cast<const TopK&>(def); | |||
return opr::TopK::make(inputs[0], inputs[1], topk.param())[0] | |||
.node()->owner_opr(); | |||
} | |||
OP_TRAIT_REG(TopK, TopK) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // top_k | |||
namespace { namespace reduce { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& reduce = static_cast<const Reduce&>(def); | |||
if (inputs.size() > 1) { | |||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]); | |||
} else { | |||
return opr::Reduce::make(inputs[0], reduce.param()); | |||
} | |||
} | |||
OP_TRAIT_REG(Reduce, Reduce) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // reduce | |||
namespace { namespace adaptive_pooling { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& pool = static_cast<const AdaptivePooling&>(def); | |||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param()); | |||
} | |||
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // adaptive_pooling | |||
namespace { namespace conv_bias { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const ConvBias&>(def); | |||
cg::OperatorNodeConfig config{conv.dtype}; | |||
if (inputs.size() == 2) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 3) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 4) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
} | |||
mgb_assert(0); | |||
} | |||
OP_TRAIT_REG(ConvBias, ConvBias) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // conv_bias | |||
namespace { namespace batch_conv_bias { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const BatchConvBias&>(def); | |||
cg::OperatorNodeConfig config{conv.dtype}; | |||
if (inputs.size() == 2) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 3) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 4) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
} | |||
mgb_assert(0); | |||
} | |||
OP_TRAIT_REG(BatchConvBias, BatchConvBias) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // batch_conv_bias | |||
namespace { namespace pooling { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& pool = static_cast<const Pooling&>(def); | |||
return opr::Pooling::make(inputs[0], pool.param()); | |||
} | |||
OP_TRAIT_REG(Pooling, Pooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // pooling | |||
namespace { namespace matrix_mul { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& matmul = static_cast<const MatrixMul&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param()); | |||
} | |||
OP_TRAIT_REG(MatrixMul, MatrixMul) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // matrix_mul | |||
namespace { namespace batched_matrix_mul { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param()); | |||
} | |||
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // batched_matrix_mul | |||
namespace { namespace dot { | |||
auto apply_on_var_node( | |||
const OpDef&, | |||
const VarNodeArray& inputs) { | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Dot::make(inputs[0], inputs[1]); | |||
} | |||
OP_TRAIT_REG(Dot, Dot) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // dot | |||
namespace { namespace argsort { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& argsort = static_cast<const Argsort&>(def); | |||
return opr::Argsort::make(inputs[0], argsort.param()); | |||
} | |||
OP_TRAIT_REG(Argsort, Argsort) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // argsort | |||
namespace { namespace argmax { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& argmax = static_cast<const Argmax&>(def); | |||
return opr::Argmax::make(inputs[0], argmax.param()); | |||
} | |||
OP_TRAIT_REG(Argmax, Argmax) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // argmax | |||
namespace { namespace argmin { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& argmin = static_cast<const Argmin&>(def); | |||
return opr::Argmin::make(inputs[0], argmin.param()); | |||
} | |||
OP_TRAIT_REG(Argmin, Argmin) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // argmin | |||
namespace { namespace warp_perspective { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& warp = static_cast<const WarpPerspective&>(def); | |||
if (inputs.size() == 3) { | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param()); | |||
} else { | |||
mgb_assert(inputs.size() == 4); | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param()); | |||
} | |||
} | |||
OP_TRAIT_REG(WarpPerspective, WarpPerspective) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // warp_perspective | |||
namespace { namespace group_local { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& local = static_cast<const GroupLocal&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::GroupLocal::make(inputs[0], inputs[1], local.param()); | |||
} | |||
OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // group_local | |||
namespace { namespace indexing_one_hot { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const IndexingOneHot&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param()); | |||
} | |||
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // indexing_one_hot | |||
namespace { namespace indexing_set_one_hot { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const IndexingSetOneHot&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
} | |||
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // indexing_set_one_hot | |||
namespace { namespace typecvt { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const TypeCvt&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::TypeCvt::make(inputs[0], op.dtype); | |||
} | |||
OP_TRAIT_REG(TypeCvt, TypeCvt) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // typecvt | |||
namespace { namespace concat { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Concat&>(def); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
return opr::Concat::make(inputs, op.axis, config); | |||
} | |||
OP_TRAIT_REG(Concat, Concat) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // concat | |||
namespace { namespace copy { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Copy&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
return opr::Copy::make(inputs[0], config); | |||
} | |||
OP_TRAIT_REG(Copy, Copy) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // copy | |||
namespace { namespace identity { | |||
auto apply_on_var_node( | |||
const OpDef&, | |||
const VarNodeArray& inputs) { | |||
mgb_assert(inputs.size() == 1); | |||
return opr::Identity::make(inputs[0]); | |||
} | |||
OP_TRAIT_REG(Identity, Identity) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // identity | |||
namespace { namespace uniform_rng { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const UniformRNG&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::UniformRNG::make(inputs[0], op.param()); | |||
} | |||
OP_TRAIT_REG(UniformRNG, UniformRNG) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // uniform_rng | |||
namespace { namespace gaussian_rng { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const GaussianRNG&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::GaussianRNG::make(inputs[0], op.param()); | |||
} | |||
OP_TRAIT_REG(GaussianRNG, GaussianRNG) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // gaussian_rng | |||
namespace { namespace roi_align { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ROIAlign&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::ROIAlign::make(inputs[0], inputs[1], op.param()); | |||
} | |||
OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // roi_align | |||
#if MGB_CUDA
namespace {
namespace nvof {
// Lower an NvOf OpDef (CUDA-only, hence the MGB_CUDA guard); requires
// exactly 1 input var.
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    const auto& nvof_def = static_cast<const NvOf&>(def);
    return opr::NvOf::make(inputs[0], nvof_def.param());
}
OP_TRAIT_REG(NvOf, NvOf)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
} // namespace nvof
} // anonymous namespace
#endif
namespace { namespace linspace { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Linspace&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Linspace, Linspace) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // linspace | |||
namespace { namespace eye { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Eye&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
opr::Eye::Param param{op.k, op.dtype.enumv()}; | |||
return opr::Eye::make(inputs[0], param, config); | |||
} | |||
OP_TRAIT_REG(Eye, Eye) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // eye | |||
namespace { namespace roi_pooling { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ROIPooling&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
} | |||
OP_TRAIT_REG(ROIPooling, ROIPooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // roi_pooling | |||
namespace { namespace remap { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Remap&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Remap::make(inputs[0], inputs[1], op.param()); | |||
} | |||
OP_TRAIT_REG(Remap, Remap) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // remap | |||
namespace { namespace reshape { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Reshape&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Reshape::make(inputs[0], inputs[1], op.param()); | |||
} | |||
OP_TRAIT_REG(Reshape, Reshape) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // reshape | |||
namespace {
// Translate a fancy-indexing mask into an opr::Subtensor::IndexDesc.
//
// `mask` holds one tuple per indexed axis: (axis, has_begin, has_end,
// has_step, is_scalar_idx). For each entry, the var nodes for the components
// that are present are consumed from `inputs`, starting at position `vidx`,
// in the fixed order begin, end, step — or a single idx var for scalar
// indexing. The final assert checks that exactly all remaining inputs were
// consumed, i.e. the mask and the input list agree.
auto get_index(
    const VarNodeArray& inputs, size_t vidx,
    const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++ i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            // scalar index: one var selects a single position on this axis
            ret[i].idx = inputs[vidx++];
        } else {
            // slice: at least one of begin/end/step must be present
            mgb_assert(begin || end || step);
            if (begin) ret[i].begin = inputs[vidx++];
            if (end) ret[i].end = inputs[vidx++];
            if (step) ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
// IN1/IN2 expand to the leading data-tensor arguments passed to the opr;
// IN##NR_INPUT below pastes to one of these. Remaining inputs are the
// index-component vars handled by get_index().
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]
// Stamp out apply_on_var_node + trait registration for one fancy-indexing
// opr. NR_INPUT is the number of leading non-index inputs: 1 for read-style
// ops, 2 for set/incr-style ops (which also take a value tensor).
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
        const OpDef& def, \
        const VarNodeArray& inputs) { \
    auto&& op = static_cast<const NAME&>(def); \
    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
} \
OP_TRAIT_REG(NAME, NAME) \
    .apply_on_var_node(apply_on_var_node) \
    .fallback(); \
}
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace
namespace { namespace fake_quant { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const FakeQuant&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
} | |||
OP_TRAIT_REG(FakeQuant, FakeQuant) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // fake_quant | |||
namespace { namespace elemwise_multi_type { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ElemwiseMultiType&>(def); | |||
OperatorNodeConfig config{op.dtype}; | |||
return opr::ElemwiseMultiType::make(inputs, op.param(), config); | |||
} | |||
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // fake_quant | |||
namespace { namespace svd { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const SVD&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::SVD::make(inputs[0], op.param()); | |||
} | |||
OP_TRAIT_REG(SVD, SVD) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // svd | |||
} // namespace mgb::imperative |
@@ -9,7 +9,7 @@ | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#include "megbrain/imperative/ops/tensor_manip.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "../op_trait.h" | |||
@@ -140,8 +140,4 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | |||
.fallback(); | |||
} // namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); | |||
} // namespace mgb::imperative |
@@ -130,7 +130,7 @@ void Profiler::start(uint32_t flags) { | |||
// TODO: assign parent | |||
entry.parent = 0; | |||
// Record apply context and save to m_profile | |||
entry.op = def.copy(); | |||
entry.op = const_cast<OpDef&>(def).shared_from_this(); | |||
for (auto&& input : inputs) { | |||
entry.inputs.push_back({m_tensor_recorder.record_tensor(input), | |||
shape2vector(input->layout()), | |||
@@ -172,31 +172,31 @@ void Profiler::start(uint32_t flags) { | |||
if (flags & PROFILE_FOOTPRINT) { | |||
hook_apply_on_var_node->apply_hook( | |||
[this](auto&& apply, const OpDef& def, | |||
VarNodeArray inputs) -> cg::OperatorNodeBase* { | |||
auto* operator_node = apply(def, std::move(inputs)); | |||
VarNodeArray inputs) -> VarNodeArray { | |||
auto vars = apply(def, std::move(inputs)); | |||
std::remove_reference_t<decltype(m_entry_stack.top())> | |||
top; | |||
{ | |||
MGB_LOCK_GUARD(m_lock); | |||
if (m_entry_stack.empty()) { | |||
return operator_node; | |||
return vars; | |||
} | |||
top = m_entry_stack.top(); | |||
} | |||
auto [current_op, current_entry, thread_id] = top; | |||
if (current_op != &def || | |||
thread_id != std::this_thread::get_id()) { | |||
return operator_node; | |||
return vars; | |||
} | |||
auto&& footprint_result = | |||
footprint.calc_footprint(operator_node); | |||
footprint.calc_footprint(vars[0]->owner_opr()); | |||
current_entry->memory = footprint_result.memory; | |||
current_entry->computation = | |||
footprint_result.computation; | |||
#if MGB_ENABLE_JSON | |||
current_entry->param = footprint_result.param; | |||
#endif | |||
return operator_node; | |||
return vars; | |||
}); | |||
} | |||
m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); | |||
@@ -590,7 +590,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr( | |||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||
vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node(); | |||
} | |||
auto opr = OpDef::apply_on_var_node(opdef, vinputs); | |||
auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr(); | |||
mgb_assert(!opr->same_type<InputPlaceholder>()); | |||
for (auto &&i : opr->input()) { | |||
mgb_assert(i->owner_opr()->same_type<InputPlaceholder>()); | |||
@@ -639,7 +639,7 @@ ProxyGraph::make_backward_graph( | |||
return ret.first->second; | |||
}; | |||
auto inputs = make_input_place_holders(input_descs); | |||
auto fwd = OpDef::apply_on_var_node(opdef, inputs); | |||
auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr(); | |||
auto&& outputs = fwd->usable_output(); | |||
SmallVector<LogicalTensorDesc> output_descs; | |||
for (auto&& i : outputs) { | |||
@@ -799,7 +799,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef, | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
mgb_assert(!m_cur_opr); | |||
auto vinputs = make_input_place_holders(inputs); | |||
return OpDef::apply_on_var_node(opdef, vinputs); | |||
return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr(); | |||
} | |||
VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) { | |||
@@ -26,13 +26,12 @@ struct BackwardGraphResult { | |||
std::vector<bool> input_has_grad; | |||
}; | |||
class OpDef : public Hashable { | |||
class OpDef : public Hashable, | |||
public std::enable_shared_from_this<OpDef> { | |||
mutable const OpTrait* m_trait = nullptr; | |||
public: | |||
virtual ~OpDef() = default; | |||
virtual std::shared_ptr<OpDef> copy() const = 0; | |||
static std::shared_ptr<OpDef> make_from_op_node( | |||
cg::OperatorNodeBase* node); | |||
@@ -40,7 +39,7 @@ public: | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs); | |||
static cg::OperatorNodeBase* apply_on_var_node( | |||
static cg::VarNodeArray apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs); | |||
@@ -56,25 +55,17 @@ public: | |||
const OpTrait* trait() const; | |||
virtual size_t hash() const { | |||
mgb_throw(MegBrainError, "not implemented"); | |||
} | |||
virtual size_t hash() const; | |||
virtual bool is_same_st(const Hashable&) const { | |||
mgb_throw(MegBrainError, "not implemented"); | |||
} | |||
virtual bool is_same_st(const Hashable&) const; | |||
}; | |||
template<typename T> | |||
class OpDefImplBase : public OpDef { | |||
public: | |||
virtual std::shared_ptr<OpDef> copy() const override { | |||
return std::shared_ptr<OpDef>(new T(this->cast_final_safe<T>())); | |||
} | |||
template<typename ...Args> | |||
static std::shared_ptr<OpDef> make(const Args& ...args) { | |||
return std::shared_ptr<OpDef>(new T(args...)); | |||
static std::shared_ptr<OpDef> make(Args&& ...args) { | |||
return std::make_shared<T>(std::forward<Args>(args)...); | |||
} | |||
}; | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/cond_take.h | |||
* \file imperative/src/include/megbrain/imperative/ops/autogen.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
@@ -12,22 +12,15 @@ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megbrain/opr/param_defs.h" | |||
namespace mgb::imperative { | |||
class CondTake : public OpDefImplBase<CondTake> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
CondTake() = default; | |||
#include "megbrain/utils/hash.h" | |||
size_t hash() const override { | |||
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return rhs.dyn_typeinfo() == dyn_typeinfo(); | |||
} | |||
namespace mgb::imperative { | |||
}; | |||
// TODO: split into separate files to avoid recompiling all | |||
// impl/ops/*.cpp on each modification of ops.td | |||
#include "./opdef.h.inl" | |||
} // namespace mgb::imperative |
@@ -1,70 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/batch_norm.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/utils/hash.h" | |||
namespace mgb::imperative { | |||
class BatchNorm : public OpDefImplBase<BatchNorm> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
using Param = opr::BatchNorm::Param; | |||
Param::ParamDim param_dim; | |||
Param::FwdMode fwd_mode; | |||
double epsilon; | |||
double avg_factor; | |||
float scale; | |||
float bias; | |||
BatchNorm() = default; | |||
BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_, | |||
double epsilon_, double avg_factor_, float scale_, float bias_) | |||
: param_dim(param_dim_), | |||
fwd_mode(fwd_mode_), | |||
epsilon(epsilon_), | |||
avg_factor(avg_factor_), | |||
scale(scale_), | |||
bias(bias_) {} | |||
size_t hash() const override { | |||
XXHash xxhash{}; | |||
auto append = [&xxhash](auto field){ | |||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
}; | |||
append(param_dim); | |||
append(fwd_mode); | |||
append(epsilon); | |||
append(avg_factor); | |||
append(scale); | |||
append(bias); | |||
return xxhash.digest(); | |||
} | |||
bool is_same_st(const Hashable& rhs_) const override { | |||
auto&& rhs = static_cast<const BatchNorm&>(rhs_); | |||
return rhs.param_dim == param_dim | |||
&& rhs.fwd_mode == fwd_mode | |||
&& rhs.epsilon == epsilon | |||
&& rhs.avg_factor == avg_factor | |||
&& rhs.scale == scale | |||
&& rhs.bias == bias; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -1,35 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/broadcast.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb::imperative { | |||
class Broadcast : public OpDefImplBase<Broadcast> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
Broadcast() = default; | |||
size_t hash() const override { | |||
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return true; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -1,69 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/collective_comm.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/opr/param_defs.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class CollectiveComm : public OpDefImplBase<CollectiveComm> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
using Mode = megdnn::param::CollectiveComm::Mode; | |||
CollectiveComm() = default; | |||
CollectiveComm(const std::string& key_, size_t nr_devices_, | |||
uint32_t rank_, bool is_root_, bool local_grad_, | |||
const std::string& addr_, uint32_t port_, | |||
const Mode& mode_, | |||
const DType& dtype_, const std::string& backend_, | |||
const std::string& comp_node_) | |||
: key(key_), | |||
nr_devices(nr_devices_), | |||
rank(rank_), | |||
is_root(is_root_), | |||
local_grad(local_grad_), | |||
addr(addr_), | |||
port(port_), | |||
mode(mode_), | |||
dtype(dtype_), | |||
backend(backend_), | |||
comp_node(comp_node_) {} | |||
std::string key; | |||
size_t nr_devices; | |||
uint32_t rank; | |||
bool is_root; | |||
bool local_grad; | |||
std::string addr; | |||
uint32_t port; | |||
Mode mode; | |||
DType dtype; | |||
std::string backend; | |||
std::string comp_node; | |||
size_t hash() const override; | |||
bool is_same_st(const Hashable& another) const override; | |||
auto as_tuple() const{ | |||
return std::tuple(key, nr_devices, rank, is_root, | |||
local_grad, addr, port, mode, dtype, | |||
backend, comp_node); | |||
} | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -1,42 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/elemwise.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb::imperative { | |||
class Elemwise : public OpDefImplBase<Elemwise> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
using Mode = opr::Elemwise::Mode; | |||
using ModeTrait = megdnn::Elemwise::ModeTrait; | |||
Mode mode; | |||
Elemwise() = default; | |||
Elemwise(const Mode& mode_): mode(mode_) {} | |||
size_t hash() const override { | |||
return hash_pair_combine(mgb::hash(mode), reinterpret_cast<std::uintptr_t>(dyn_typeinfo())); | |||
} | |||
bool is_same_st(const Hashable& rhs_) const override { | |||
auto&& rhs = static_cast<const Elemwise&>(rhs_); | |||
return rhs.mode == mode; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -1,77 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/io_remote.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class RemoteSend : public OpDefImplBase<RemoteSend> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
RemoteSend() = default; | |||
RemoteSend(const std::string& key_, const std::string& addr_, | |||
uint32_t port_, uint32_t rank_to_) | |||
: key(key_), | |||
addr(addr_), | |||
port(port_), | |||
rank_to(rank_to_) {} | |||
std::string key; | |||
std::string addr; | |||
uint32_t port; | |||
uint32_t rank_to; | |||
size_t hash() const override; | |||
bool is_same_st(const Hashable& another) const override; | |||
auto as_tuple() const{ | |||
return std::tuple(key, addr, port, rank_to); | |||
} | |||
}; | |||
class RemoteRecv : public OpDefImplBase<RemoteRecv> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
RemoteRecv() = default; | |||
RemoteRecv(const std::string& key_, const std::string& addr_, | |||
uint32_t port_, uint32_t rank_from_, TensorShape shape_, | |||
CompNode cn_, const DType& dtype_) | |||
: key(key_), | |||
addr(addr_), | |||
port(port_), | |||
rank_from(rank_from_), | |||
cn(cn_), | |||
shape(shape_), | |||
dtype(dtype_) {} | |||
std::string key; | |||
std::string addr; | |||
uint32_t port; | |||
uint32_t rank_from; | |||
CompNode cn; | |||
TensorShape shape; | |||
DType dtype; | |||
size_t hash() const override; | |||
bool is_same_st(const Hashable& another) const override; | |||
auto as_tuple() const{ | |||
return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); | |||
} | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -1,41 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/nms.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb::imperative { | |||
class NMSKeep : public OpDefImplBase<NMSKeep> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
float iou_thresh; //!< IoU threshold for overlapping | |||
uint32_t max_output; //!< max number of output boxes per batch | |||
NMSKeep() = default; | |||
NMSKeep(float iou_thresh_, uint32_t max_output_): | |||
iou_thresh(iou_thresh_), max_output(max_output_) {} | |||
size_t hash() const override { | |||
return hash_pair_combine( | |||
hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)), | |||
reinterpret_cast<std::uintptr_t>(dyn_typeinfo())); | |||
} | |||
bool is_same_st(const Hashable& rhs_) const override { | |||
auto&& rhs = static_cast<const NMSKeep&>(rhs_); | |||
return rhs.iou_thresh == iou_thresh | |||
&& rhs.max_output == max_output; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -1,99 +0,0 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/tensor_manip.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/utils/hash.h" | |||
namespace mgb::imperative { | |||
class GetVarShape : public OpDefImplBase<GetVarShape> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
GetVarShape() = default; | |||
size_t hash() const override { | |||
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return rhs.dyn_typeinfo() == dyn_typeinfo(); | |||
} | |||
}; | |||
class ParamPackSplit : public OpDefImplBase<ParamPackSplit> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
ParamPackSplit() = default; | |||
ParamPackSplit(std::vector<dt_int32>& offsets_, | |||
std::vector<std::vector<size_t>>& shapes_) | |||
: offsets(offsets_), shapes(shapes_) {} | |||
std::vector<dt_int32> offsets; | |||
std::vector<std::vector<size_t>> shapes; | |||
size_t hash() const override { | |||
XXHash builder; | |||
for (auto&& offset : offsets) { | |||
builder.update(&offset, sizeof(offset)); | |||
} | |||
auto&& offset_cnt = offsets.size(); | |||
builder.update(&offset_cnt, sizeof(offset_cnt)); | |||
for (auto&& shape : shapes) { | |||
for (auto&& dim_len : shape) { | |||
builder.update(&dim_len, sizeof(dim_len)); | |||
} | |||
auto&& dim_cnt = shape.size(); | |||
builder.update(&dim_cnt, sizeof(dim_cnt)); | |||
} | |||
auto&& shape_cnt = shapes.size(); | |||
builder.update(&shape_cnt, sizeof(shape_cnt)); | |||
return builder.digest(); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
auto&& pps = rhs.cast_final_safe<ParamPackSplit>(); | |||
return offsets == pps.offsets && shapes == pps.shapes; | |||
} | |||
}; | |||
class ParamPackConcat : public OpDefImplBase<ParamPackConcat> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
ParamPackConcat() = default; | |||
ParamPackConcat(std::vector<dt_int32>& offsets_) | |||
: offsets(offsets_) {} | |||
std::vector<dt_int32> offsets; | |||
size_t hash() const override { | |||
XXHash builder; | |||
for (auto&& offset : offsets) { | |||
builder.update(&offset, sizeof(offset)); | |||
} | |||
auto&& offset_cnt = offsets.size(); | |||
builder.update(&offset_cnt, sizeof(offset_cnt)); | |||
return builder.digest(); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
auto&& ppc = rhs.cast_final_safe<ParamPackConcat>(); | |||
return offsets == ppc.offsets; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -29,18 +29,18 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
using Param = opr::Elemwise::Param; | |||
Param param{Param::Mode::MUL}; | |||
OprAttr attr{"Elemwise", {}, {}}; | |||
attr.param.write_pod(param); | |||
auto attr = OprAttr::make("Elemwise"); | |||
attr->cast_final_safe<OprAttr>().param.write_pod(param); | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
for (auto&& i : inputs) { | |||
input_descs.push_back({i->layout(), i->comp_node()}); | |||
} | |||
auto result = OpDef::make_backward_graph(attr, input_descs, {true, true}, {true}); | |||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true}); | |||
auto&& save_for_backward = result.save_for_backward; | |||
auto&& input_has_grad = result.input_has_grad; | |||
auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); | |||
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | |||
inputs.push_back(outputs[0]); | |||
hvs.push_back(*gen({42})); | |||
inputs.push_back(Tensor::make(hvs.back())); | |||
@@ -82,16 +82,16 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
SmallVector<TensorPtr> inputs; | |||
inputs.push_back(a); | |||
OprAttr attr{"Identity", {}, {}}; | |||
attr.param.write_pod<megdnn::param::Empty>({}); | |||
auto attr = OprAttr::make("Identity"); | |||
attr->cast_final_safe<OprAttr>().param.write_pod<megdnn::param::Empty>({}); | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
input_descs.push_back({a->layout(), a->comp_node()}); | |||
auto result = OpDef::make_backward_graph(attr, input_descs, {true}, {true}); | |||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||
auto&& save_for_backward = result.save_for_backward; | |||
auto&& input_has_grad = result.input_has_grad; | |||
auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); | |||
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | |||
inputs.push_back(outputs[0]); | |||
inputs.push_back(dc); | |||
mgb_assert(save_for_backward.size() == inputs.size()); | |||
@@ -10,7 +10,7 @@ | |||
*/ | |||
#include "./helper.h" | |||
#include "megbrain/imperative/ops/collective_comm.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/mm_handler.h" | |||
using namespace mgb; | |||
@@ -32,12 +32,13 @@ TEST(TestImperative, AllReduceBasic) { | |||
} | |||
auto run = [&](std::shared_ptr<HostTensorND> hnd, uint32_t idx) { | |||
imperative::CollectiveComm | |||
def{"all_reduce", 2, idx, idx==0, false, server_addr, port, | |||
auto def = | |||
imperative::CollectiveComm::make( | |||
megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, | |||
dtype::Float32(), "nccl", ""}; | |||
"all_reduce", 2, idx, idx==0, false, server_addr, port, | |||
dtype::Float32(), "nccl", ""); | |||
auto inp = Tensor::make(*hnd); | |||
auto oup = OpDef::apply_on_physical_tensor(def, {inp}); | |||
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | |||
HostTensorND host_v; | |||
host_v.copy_from(oup[0]->dev_tensor()).sync(); | |||
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); | |||
@@ -10,7 +10,7 @@ | |||
*/ | |||
#include "./helper.h" | |||
#include "megbrain/imperative/ops/cond_take.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
using namespace mgb; | |||
using namespace imperative; | |||
@@ -119,7 +119,7 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) { | |||
}, inp_keys[i]); | |||
sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node(); | |||
} | |||
auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp)->usable_output(); | |||
auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp); | |||
size_t nr_oups = sym_oup.size(); | |||
ComputingGraph::OutputSpec oup_spec(nr_oups); | |||
SmallVector<HostTensorND> host_sym_oup(nr_oups); | |||
@@ -10,7 +10,7 @@ | |||
*/ | |||
#include "./helper.h" | |||
#include "megbrain/imperative/ops/io_remote.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/mm_handler.h" | |||
using namespace mgb; | |||
@@ -33,24 +33,19 @@ TEST(TestImperative, IORemote) { | |||
} | |||
auto run_send = [&](std::shared_ptr<HostTensorND> hnd) { | |||
imperative::RemoteSend def{"io_remote_test", server_addr, port, 1}; | |||
auto def = imperative::RemoteSend::make( | |||
"io_remote_test", server_addr, port, 1); | |||
auto inp = Tensor::make(*hnd); | |||
auto oup = OpDef::apply_on_physical_tensor(def, {inp}); | |||
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | |||
}; | |||
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { | |||
// auto&& shape = std::initializer_list{vector_size}; | |||
imperative::RemoteRecv def{"io_remote_test", | |||
server_addr, | |||
port, | |||
0, | |||
{ | |||
vector_size, | |||
}, | |||
CompNode::load("gpu1"), | |||
dtype::Float32()}; | |||
auto def = imperative::RemoteRecv::make( | |||
"io_remote_test", server_addr, port, 0, | |||
CompNode::load("gpu1"), TensorShape{vector_size}, | |||
dtype::Float32()); | |||
auto inp = Tensor::make(*hnd); | |||
auto oup = OpDef::apply_on_physical_tensor(def, {inp}); | |||
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | |||
HostTensorND host_v; | |||
host_v.copy_from(oup[0]->dev_tensor()).sync(); | |||
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); | |||
@@ -0,0 +1,14 @@ | |||
# mgb tablegen executable | |||
set(TABLE_TARGET mgb-mlir-autogen) | |||
add_executable(${TABLE_TARGET} autogen.cpp) | |||
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) | |||
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) | |||
set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) | |||
# generate megbrain opdef c header and python bindings | |||
set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) | |||
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") | |||
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") | |||
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") | |||
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) | |||
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) |
@@ -0,0 +1,383 @@ | |||
#include <iostream> | |||
#include <unordered_map> | |||
#include <functional> | |||
#include "./helper.h" | |||
using llvm::raw_ostream; | |||
using llvm::RecordKeeper; | |||
enum ActionType { | |||
None, | |||
CppHeader, | |||
CppBody, | |||
Pybind | |||
}; | |||
// NOLINTNEXTLINE | |||
llvm::cl::opt<ActionType> action( | |||
llvm::cl::desc("Action to perform:"), | |||
llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header", | |||
"Generate operator cpp header"), | |||
clEnumValN(CppBody, "gen-cpp-body", | |||
"Generate operator cpp body"), | |||
clEnumValN(Pybind, "gen-python-binding", | |||
"Generate pybind11 python bindings"))); | |||
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
using MgbOp = mlir::tblgen::MgbOpBase; | |||
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
// Note: we have already registered the corresponding attr wrappers | |||
// for following basic ctypes so we needn't handle them here | |||
/* auto&& attr_type_name = attr.getAttrDefName(); | |||
if (attr_type_name == "UI32Attr") { | |||
return "uint32_t"; | |||
} | |||
if (attr_type_name == "UI64Attr") { | |||
return "uint64_t"; | |||
} | |||
if (attr_type_name == "I32Attr") { | |||
return "int32_t"; | |||
} | |||
if (attr_type_name == "F32Attr") { | |||
return "float"; | |||
} | |||
if (attr_type_name == "F64Attr") { | |||
return "double"; | |||
} | |||
if (attr_type_name == "StrAttr") { | |||
return "std::string"; | |||
} | |||
if (attr_type_name == "BoolAttr") { | |||
return "bool"; | |||
}*/ | |||
auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
return e->getEnumName(); | |||
} | |||
return attr.getUnderlyingType(); | |||
} | |||
static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
os << formatv( | |||
"class {0} : public OpDefImplBase<{0}> {{\n" | |||
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
"public:\n", | |||
op.getCppClassName() | |||
); | |||
// handle enum alias | |||
for (auto &&i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
os << formatv( | |||
" using {0} = {1};\n", | |||
attr->getEnumName(), attr->getUnderlyingType() | |||
); | |||
} | |||
} | |||
for (auto &&i : op.getMgbAttributes()) { | |||
auto defaultValue = i.attr.getDefaultValue().str(); | |||
if (!defaultValue.empty()) { | |||
defaultValue = formatv(" = {0}", defaultValue); | |||
} | |||
os << formatv( | |||
" {0} {1}{2};\n", | |||
attr_to_ctype(i.attr), i.name, defaultValue | |||
); | |||
} | |||
auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
os << formatv( | |||
" {0}({1}){2}{3}\n", | |||
op.getCppClassName(), paramList, memInitList, body | |||
); | |||
}; | |||
gen_ctor("", "", " = default;"); | |||
if (!op.getMgbAttributes().empty()) { | |||
std::vector<std::string> paramList, initList; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
paramList.push_back(formatv( | |||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||
)); | |||
initList.push_back(formatv( | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
gen_ctor(llvm::join(paramList, ", "), | |||
": " + llvm::join(initList, ", "), | |||
" {}"); | |||
} | |||
auto packedParams = op.getPackedParams(); | |||
if (!packedParams.empty()) { | |||
std::vector<std::string> paramList, initList; | |||
for (auto &&p : packedParams) { | |||
auto&& paramFields = p.getFields(); | |||
auto&& paramType = p.getFullName(); | |||
auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
paramList.push_back( | |||
paramFields.empty() ? paramType.str() | |||
: formatv("{0} {1}", paramType, paramName) | |||
); | |||
for (auto&& i : paramFields) { | |||
initList.push_back(formatv( | |||
"{0}({1}.{0})", i.name, paramName | |||
)); | |||
} | |||
} | |||
for (auto&& i : op.getExtraArguments()) { | |||
paramList.push_back(formatv( | |||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||
)); | |||
initList.push_back(formatv( | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
gen_ctor(llvm::join(paramList, ", "), | |||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
" {}"); | |||
} | |||
if (!packedParams.empty()) { | |||
for (auto&& p : packedParams) { | |||
auto accessor = p.getAccessor(); | |||
if (!accessor.empty()) { | |||
os << formatv( | |||
" {0} {1}() const {{\n", | |||
p.getFullName(), accessor | |||
); | |||
std::vector<llvm::StringRef> fields; | |||
for (auto&& i : p.getFields()) { | |||
fields.push_back(i.name); | |||
} | |||
os << formatv( | |||
" return {{{0}};\n", | |||
llvm::join(fields, ", ") | |||
); | |||
os << " }\n"; | |||
} | |||
} | |||
} | |||
if (auto decl = op.getExtraOpdefDecl()) { | |||
os << decl.getValue(); | |||
} | |||
os << formatv( | |||
"};\n\n" | |||
); | |||
} | |||
static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
auto&& className = op.getCppClassName(); | |||
os << formatv( | |||
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
); | |||
auto formatMethImpl = [&](auto&& meth) { | |||
return formatv( | |||
"{0}_{1}_impl", className, meth | |||
); | |||
}; | |||
std::vector<std::string> methods; | |||
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
os << "namespace {\n"; | |||
// generate hash() | |||
mlir::tblgen::FmtContext ctx; | |||
os << formatv( | |||
"size_t {0}(const OpDef& def_) {{\n", | |||
formatMethImpl("hash") | |||
); | |||
os << formatv( | |||
" auto op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
// generate is_same_st() | |||
os << formatv( | |||
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
formatMethImpl("is_same_st") | |||
); | |||
os << formatv( | |||
" auto a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
" b_ = rhs_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(a_);\n" | |||
" static_cast<void>(b_);\n", | |||
className | |||
); | |||
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
os << "}\n"; | |||
os << "} // anonymous namespace\n"; | |||
methods.push_back("hash"); | |||
methods.push_back("is_same_st"); | |||
} | |||
if (!methods.empty()) { | |||
os << formatv( | |||
"OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
); | |||
for (auto&& i : methods) { | |||
os << formatv( | |||
"\n .{0}({1})", i, formatMethImpl(i) | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
} | |||
struct PybindContext { | |||
std::unordered_map<unsigned int, std::string> enumAlias; | |||
}; | |||
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { | |||
auto class_name = op.getCppClassName(); | |||
os << formatv( | |||
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
class_name | |||
); | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = | |||
llvm::cast<MgbEnumAttr>(aliasBase) | |||
.getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
class_name, attr->getEnumName() | |||
); | |||
std::vector<std::string> body; | |||
for (auto&& i: attr->getEnumMembers()) { | |||
os << formatv( | |||
"\n .value(\"{2}\", {0}::{1}::{2})", | |||
class_name, attr->getEnumName(), i | |||
); | |||
body.push_back(formatv( | |||
"if (str == \"{2}\") return {0}::{1}::{2};", | |||
class_name, attr->getEnumName(), i | |||
)); | |||
} | |||
os << formatv( | |||
"\n .def(py::init([](const std::string& in) {" | |||
"\n auto&& str = normalize_enum(in);" | |||
"\n {0}" | |||
"\n throw py::cast_error(\"invalid enum value \" + in);" | |||
"\n }));\n", | |||
llvm::join(body, "\n ") | |||
); | |||
os << formatv( | |||
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
class_name, attr->getEnumName() | |||
); | |||
enumAlias.emplace(enumID, formatv( | |||
"{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() | |||
)); | |||
} else { | |||
os << formatv( | |||
"{0}Inst.attr(\"{1}\") = {2};\n\n", | |||
class_name, attr->getEnumName(), iter->second | |||
); | |||
} | |||
} | |||
} | |||
// generate op class binding | |||
os << formatv("{0}Inst", class_name); | |||
bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
if (!hasDefaultCtor) { | |||
os << "\n .def(py::init<"; | |||
std::vector<llvm::StringRef> targs; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
targs.push_back(i.attr.getReturnType()); | |||
} | |||
os << llvm::join(targs, ", "); | |||
os << ">()"; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv(", py::arg(\"{0}\")", i.name); | |||
auto defaultValue = i.attr.getDefaultValue(); | |||
if (!defaultValue.empty()) { | |||
os << formatv(" = {0}", defaultValue); | |||
} else { | |||
hasDefaultCtor = true; | |||
} | |||
} | |||
os << ")"; | |||
} | |||
if (hasDefaultCtor) { | |||
os << "\n .def(py::init<>())"; | |||
} | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv( | |||
"\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
i.name, class_name | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
auto op_base_class = keeper.getClass("Op"); | |||
ASSERT(op_base_class, "could not find base class Op"); | |||
for (auto&& i: keeper.getDefs()) { | |||
auto&& r = i.second; | |||
if (r->isSubClassOf(op_base_class)) { | |||
auto op = mlir::tblgen::Operator(r.get()); | |||
if (op.getDialectName().str() == "mgb") { | |||
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
callback(os, llvm::cast<MgbOp>(op)); | |||
} | |||
} | |||
} | |||
} | |||
static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
return false; | |||
} | |||
static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
for_each_operator(os, keeper, gen_op_def_c_body_single); | |||
return false; | |||
} | |||
static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
PybindContext ctx; | |||
using namespace std::placeholders; | |||
for_each_operator(os, keeper, | |||
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
return false; | |||
} | |||
int main(int argc, char **argv) { | |||
llvm::InitLLVM y(argc, argv); | |||
llvm::cl::ParseCommandLineOptions(argc, argv); | |||
if (action == ActionType::CppHeader) { | |||
return TableGenMain(argv[0], &gen_op_def_c_header); | |||
} | |||
if (action == ActionType::CppBody) { | |||
return TableGenMain(argv[0], &gen_op_def_c_body); | |||
} | |||
if (action == ActionType::Pybind) { | |||
return TableGenMain(argv[0], &gen_op_def_pybind11); | |||
} | |||
return -1; | |||
} |
@@ -0,0 +1,228 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "llvm/Support/CommandLine.h" | |||
#include "llvm/Support/FormatVariadic.h" | |||
#include "llvm/Support/InitLLVM.h" | |||
#include "llvm/Support/Signals.h" | |||
#include "llvm/TableGen/Main.h" | |||
#include "llvm/TableGen/Record.h" | |||
#include "llvm/TableGen/TableGenBackend.h" | |||
#include "mlir/TableGen/Attribute.h" | |||
#include "mlir/TableGen/Format.h" | |||
#include "mlir/TableGen/Operator.h" | |||
using llvm::formatv; | |||
using llvm::StringRef; | |||
using llvm::Record; | |||
#define ASSERT(stmt, msg) \ | |||
if (!(stmt)) { \ | |||
std::cerr << "\033[1;31m" \ | |||
<< "tablegen autogen abort due to: " << msg \ | |||
<< "\033[0m" << std::endl; \ | |||
exit(1); \ | |||
} | |||
namespace mlir { | |||
namespace tblgen { | |||
template<typename ConcreteType> | |||
struct MgbInterface : public ConcreteType { | |||
MgbInterface() = delete; | |||
MgbInterface(const MgbInterface&) = delete; | |||
MgbInterface(MgbInterface&&) = delete; | |||
~MgbInterface() = delete; | |||
}; | |||
struct MgbAttrWrapperBase : public MgbInterface<Attribute> { | |||
private: | |||
struct RecordVisitor : public MgbInterface<Constraint> { | |||
public: | |||
static bool classof(const Constraint*) { | |||
return true; | |||
} | |||
const llvm::Record* getDef() const { | |||
return def; | |||
} | |||
}; | |||
public: | |||
static bool classof(const Attribute* attr) { | |||
return attr->isSubClassOf("MgbAttrWrapperBase"); | |||
} | |||
const llvm::Record* getBaseRecord() const { | |||
auto baseAttr = getBaseAttr(); | |||
return llvm::cast<RecordVisitor>(baseAttr).getDef(); | |||
} | |||
llvm::StringRef getUnderlyingType() const { | |||
return def->getValueAsString("underlyingType"); | |||
} | |||
}; | |||
struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||
static bool classof(const Attribute* attr) { | |||
return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin"); | |||
} | |||
llvm::StringRef getParentNamespace() const { | |||
return getBaseRecord()->getValueAsString("parentNamespce"); | |||
} | |||
llvm::StringRef getEnumName() const { | |||
return getBaseRecord()->getValueAsString("enumName"); | |||
} | |||
std::vector<StringRef> getEnumMembers() const { | |||
return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | |||
} | |||
}; | |||
struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | |||
static bool classof(const Attribute* attr) { | |||
return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin"); | |||
} | |||
llvm::StringRef getHashFunctionTemplate() const { | |||
return getBaseRecord()->getValueAsString("hashFunction"); | |||
} | |||
llvm::StringRef getCmpFunctionTemplate() const { | |||
return getBaseRecord()->getValueAsString("cmpFunction"); | |||
} | |||
}; | |||
struct MgbAliasAttrMixin : public MgbAttrWrapperBase { | |||
static bool classof(const Attribute* attr) { | |||
return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin"); | |||
} | |||
Attribute getAliasBase() const { | |||
return Attribute(getBaseRecord()->getValueAsDef("aliasBase")); | |||
} | |||
}; | |||
class MgbPackedParam { | |||
public: | |||
MgbPackedParam(Record* def_): def(def_) { | |||
auto&& dag = def->getValueAsDag("fields"); | |||
for (size_t i = 0; i < dag->getNumArgs(); ++ i) { | |||
fields.push_back({ | |||
dag->getArgNameStr(i), | |||
Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i))) | |||
}); | |||
} | |||
} | |||
llvm::StringRef getFullName() const { | |||
return def->getValueAsString("fullName"); | |||
} | |||
std::vector<NamedAttribute> getFields() const { | |||
return fields; | |||
} | |||
llvm::StringRef getAccessor() const { | |||
return def->getValueAsString("paramAccessor"); | |||
} | |||
private: | |||
std::vector<NamedAttribute> fields; | |||
Record* def; | |||
}; | |||
struct MgbOpBase : public MgbInterface<Operator> { | |||
static bool isPackedParam(Record* def) { | |||
return def->isSubClassOf("MgbPackedParamBase"); | |||
} | |||
public: | |||
static bool classof(const Operator* op) { | |||
return op->getDef().isSubClassOf("MgbOp"); | |||
} | |||
std::vector<NamedAttribute> getMgbAttributes() const { | |||
std::vector<NamedAttribute> ret; | |||
for (auto&& i: getAttributes()) { | |||
if (isa<MgbAttrWrapperBase>(i.attr)) { | |||
ret.push_back(i); | |||
} | |||
} | |||
return ret; | |||
} | |||
std::vector<NamedAttribute> getExtraArguments() const { | |||
std::vector<NamedAttribute> ret; | |||
auto&& dag = getDef().getValueAsDag("extraArguments"); | |||
for (size_t i = 0; i < dag->getNumArgs(); ++ i) { | |||
ret.push_back({ | |||
dag->getArgNameStr(i), | |||
Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i))) | |||
}); | |||
} | |||
return ret; | |||
} | |||
llvm::Optional<StringRef> getExtraOpdefDecl() const { | |||
return getDef().getValueAsOptionalString("extraOpdefDecl"); | |||
} | |||
std::vector<MgbPackedParam> getPackedParams() const { | |||
std::vector<MgbPackedParam> ret; | |||
for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) { | |||
if (isPackedParam(i)) { | |||
ret.emplace_back(i); | |||
} | |||
} | |||
return ret; | |||
} | |||
}; | |||
struct MgbHashableOpMixin : public MgbOpBase { | |||
private: | |||
std::string getDefaultHashFunction() const { | |||
std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n"; | |||
if (!getMgbAttributes().empty()) { | |||
auto getHashFunc = [&](auto&& iter) { | |||
auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr); | |||
return attr.getHashFunctionTemplate(); | |||
}; | |||
mlir::tblgen::FmtContext ctx; | |||
for (auto&& it: getMgbAttributes()) { | |||
body += formatv( | |||
" val = mgb::hash_pair_combine(val, {0});\n", | |||
mlir::tblgen::tgfmt(getHashFunc(it), &ctx, "$_self." + it.name) | |||
); | |||
} | |||
} | |||
body += " return val;\n"; | |||
return body; | |||
} | |||
std::string getDefaultCmpFunction() const { | |||
std::string body; | |||
if (!getMgbAttributes().empty()) { | |||
mlir::tblgen::FmtContext ctx; | |||
for (auto&& it : getMgbAttributes()) { | |||
auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | |||
body += formatv( | |||
" if ({0}) return false;\n", | |||
mlir::tblgen::tgfmt(attr.getCmpFunctionTemplate(), | |||
&ctx, "$0." + it.name, "$1." + it.name) | |||
); | |||
} | |||
} | |||
body += " return true;\n"; | |||
return body; | |||
} | |||
public: | |||
static bool classof(const Operator* op) { | |||
return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
} | |||
std::string getHashFunctionTemplate() const { | |||
if (auto f = getDef().getValueAsOptionalString("hashFunction")) { | |||
return f.getValue().str(); | |||
} | |||
return getDefaultHashFunction(); | |||
} | |||
std::string getCmpFunctionTemplate() const { | |||
if (auto f = getDef().getValueAsOptionalString("cmpFunction")) { | |||
return f.getValue().str(); | |||
} | |||
return getDefaultCmpFunction(); | |||
} | |||
}; | |||
} // namespace tblgen | |||
} // namespace mlir |
@@ -11,7 +11,7 @@ endif() | |||
# TODO: turn python binding into a static/object library | |||
add_executable(imperative_test ${SOURCES} ${SRCS}) | |||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include) | |||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) | |||
# Python binding | |||
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
@@ -0,0 +1,257 @@ | |||
/** | |||
* \file src/core/include/megbrain/ir/base.td | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#ifndef MGB_BASE | |||
#define MGB_BASE | |||
include "mlir/IR/OpBase.td" | |||
def Mgb_Dialect : Dialect { | |||
let name = "mgb"; | |||
let cppNamespace = "mgb::dialect"; | |||
} | |||
// -- mgb Attr mixin | |||
class MgbAttrWrapperBase<string className> { | |||
string underlyingType = className; | |||
int recursionDepth = 0; | |||
} | |||
class MgbHashableAttrMixin { | |||
string hashFunction = "mgb::hash($0)"; | |||
// return 0 for eq, else for ne | |||
string cmpFunction = "$0 != $1"; | |||
} | |||
class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | |||
string parentNamespace = namespace; | |||
string enumName = name; | |||
list<string> enumMembers = members; | |||
} | |||
class MgbAttrWrapper; | |||
class MgbAliasAttrMixin<Attr base> { | |||
Attr aliasBase = base; | |||
} | |||
// -- mgb custom Attr | |||
// TODO: CPred and description | |||
class MgbAttrWrapper<string className>: | |||
Attr<CPred<"true">, "TODO">, MgbAttrWrapperBase<className> { | |||
let returnType = underlyingType; | |||
} | |||
class HashableAttr<string className>: | |||
MgbAttrWrapper<className>, MgbHashableAttrMixin; | |||
// -- basic types | |||
class MgbIntegerAttrBase<string CType> : HashableAttr<CType> { | |||
let storageType = "::mlir::IntegerAttr"; | |||
} | |||
class MgbSignlessIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> { | |||
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getInt())"; | |||
let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4), $0)"; | |||
} | |||
class MgbSignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> { | |||
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getSInt())"; | |||
let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4, true), $0)"; | |||
} | |||
class MgbUnsignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> { | |||
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getUInt())"; | |||
let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4, false), $0)"; | |||
} | |||
def MgbI8Attr: MgbSignlessIntegerAttrBase<"int8_t">; | |||
def MgbI32Attr: MgbSignlessIntegerAttrBase<"int32_t">; | |||
def MgbI64Attr: MgbSignlessIntegerAttrBase<"int64_t">; | |||
def MgbUI32Attr: MgbUnsignedIntegerAttrBase<"uint32_t">; | |||
def MgbUI64Attr: MgbUnsignedIntegerAttrBase<"uint64_t">; | |||
def MgbSizeTAddr: MgbUnsignedIntegerAttrBase<"size_t">; | |||
class MgbFloatAttrBase<string CType, string DType> : HashableAttr<CType> { | |||
let storageType = "::mlir::FloatAttr"; | |||
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getValueAsDouble())"; | |||
let constBuilderCall = "$_builder.getFloatAttr($_builder.get" # DType # "Type(), $0)"; | |||
} | |||
def MgbF32Attr : MgbFloatAttrBase<"float", "F32">; | |||
def MgbF64Attr : MgbFloatAttrBase<"double", "F64">; | |||
def MgbBoolAttr : HashableAttr<"bool"> { | |||
let storageType = "::mlir::BoolAttr"; | |||
let constBuilderCall = "$_builder.getBoolAttr($0)"; | |||
} | |||
def MgbStringAttr : HashableAttr<"std::string"> { | |||
let storageType = "::mlir::StringAttr"; | |||
let convertFromStorage = "$_self.getValue().str()"; | |||
let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor | |||
} | |||
class MgbArrayAttr<MgbAttrWrapper elem>: | |||
HashableAttr<"std::vector<" # elem.underlyingType # ">"> { | |||
let storageType = "::mlir::ArrayAttr"; | |||
let recursionDepth = !add(elem.recursionDepth, 1); | |||
let convertFromStorage = | |||
"[&] {\n" | |||
" " # underlyingType # " ret" # recursionDepth # ";\n" | |||
" std::for_each($_self.begin(), $_self.end(), [&](auto&& i" # recursionDepth # ") {\n" | |||
" ret" # recursionDepth # ".push_back(\n" | |||
" " # !subst("$_self", "i" # recursionDepth # ".template cast<" # elem.storageType # ">()", "" # elem.convertFromStorage) # "\n" | |||
" );\n" | |||
" });\n" | |||
" return ret" # recursionDepth # ";}()"; | |||
let constBuilderCall = | |||
"[&] {\n" | |||
" std::vector<mlir::Attribute> ret" # recursionDepth # ";\n" | |||
" std::for_each($0.begin(), $0.end(), [&](auto&& i" # recursionDepth # ") {\n" | |||
" ret" # recursionDepth # ".push_back(\n" | |||
" " # !subst("$0", "i" # recursionDepth, "" # elem.constBuilderCall) # "\n" | |||
" );\n" | |||
" });\n" | |||
" return $_builder.getArrayAttr(ret" # recursionDepth # ");" | |||
"}()"; | |||
} | |||
defvar EmptyStrList = !listsplat("", 0); | |||
class StrListAppend<list<string> l, string s> { | |||
list<string> r = !listconcat(l, !listsplat(s, 1)); | |||
} | |||
class TupleConvertFromStorage<MgbAttrWrapper attr, int idx> { | |||
string r = !subst( | |||
"$_self", | |||
"$_self[" # !cast<string>(idx) # "].template cast<"# attr.storageType #">()", | |||
"" # attr.convertFromStorage); | |||
} | |||
class TupleConstBuilderCall<MgbAttrWrapper attr, int idx> { | |||
string r = !subst( | |||
"$0", | |||
"std::get<" # !cast<string>(idx) # ">($0)", | |||
"" # attr.constBuilderCall); | |||
} | |||
class ApplyTupleConvertFromStorage<list<MgbAttrWrapper> args> { | |||
list<string> r = !foldl( | |||
EmptyStrList, args, l, arg, StrListAppend<l, TupleConvertFromStorage<arg, !size(l)>.r>.r); | |||
} | |||
class ApplyTupleConstBuilderCall<list<MgbAttrWrapper> args> { | |||
list<string> r = !foldl( | |||
EmptyStrList, args, l, arg, StrListAppend<l, TupleConstBuilderCall<arg, !size(l)>.r>.r); | |||
} | |||
class MgbTupleAttr<list<MgbAttrWrapper> args>: | |||
HashableAttr<"std::tuple<" # StrJoin<!foreach(i, args, i.underlyingType)>.result # ">"> { | |||
let storageType = "::mlir::ArrayAttr"; | |||
let convertFromStorage = "std::make_tuple(" # StrJoin<ApplyTupleConvertFromStorage<args>.r>.result # ")"; | |||
let constBuilderCall = "$_builder.getArrayAttr({" # StrJoin<ApplyTupleConstBuilderCall<args>.r>.result # "})"; | |||
} | |||
// -- enum types | |||
class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> { | |||
let storageType = "::mlir::IntegerAttr"; | |||
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | |||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | |||
let hashFunction = "mgb::enumhash()($0)"; | |||
} | |||
class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: | |||
MgbEnumAttr<namespace, enumName, base.enumMembers>, MgbAliasAttrMixin<base>; | |||
// -- other types | |||
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { | |||
let storageType = "::mlir::IntegerAttr"; | |||
let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))"; | |||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))"; | |||
let hashFunction = "mgb::hash($0.handle())"; | |||
} | |||
def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> { | |||
let storageType = "::mlir::StringAttr"; | |||
let convertFromStorage = underlyingType # "::load($_self.getValue().str())"; | |||
let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())"; | |||
} | |||
def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { | |||
let storageType = "::mlir::ArrayAttr"; | |||
let hashFunction = "mgb::PODHash<size_t>::perform($0.shape, $0.ndim)"; | |||
let cmpFunction = "!$0.eq_shape($1)"; | |||
defvar elemInst = MgbSizeTAddr; | |||
let convertFromStorage = | |||
"[&] {\n" | |||
" " # underlyingType # " ret;\n" | |||
" std::for_each($_self.begin(), $_self.end(), [&ret](auto&& i) {\n" | |||
" ret[ret.ndim ++] = " # !subst("$_self", "i.template cast<"# elemInst.storageType #">()", "" # elemInst.convertFromStorage) # ";\n" | |||
" });\n" | |||
" return ret;}()"; | |||
let constBuilderCall = | |||
"[&] {\n" | |||
" std::vector<mlir::Attribute> ret;\n" | |||
" for (size_t i = 0; i < $0.ndim; ++ i) {\n" | |||
" ret.push_back(\n" | |||
" " # !subst("$0", "$0[i]", "" # elemInst.constBuilderCall) # "\n" | |||
" );\n" | |||
" }\n" | |||
" return $_builder.getArrayAttr(ret);" | |||
"}()"; | |||
} | |||
class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>: | |||
DefaultValuedAttr<attr, value>, MgbAttrWrapperBase<attr.underlyingType> { | |||
// Note: this class is similar to DefaultValuedAttr but with extra | |||
// meta informations which are used by mgb dialect tblgen, so this | |||
// has to be kept up to date with class MgbAttrWrapperMixin | |||
let recursionDepth = attr.recursionDepth; | |||
} | |||
// -- dnn params | |||
class MgbParamBase<string className> { | |||
string paramType = className; | |||
string fullName = "::megdnn::param::" # paramType; | |||
dag fields = ?; | |||
} | |||
class MgbPackedParamBase<string className, string accessor>: | |||
MgbParamBase<className> { | |||
string paramAccessor = accessor; | |||
} | |||
// -- mgb ops | |||
class MgbHashableOpMixin { | |||
string hashFunction = ?; | |||
string cmpFunction = ?; | |||
} | |||
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | |||
Op<Mgb_Dialect, mnemonic, traits> { | |||
dag inputs = (ins); | |||
dag extraArguments = (ins); | |||
// TODO: remove it | |||
code extraOpdefDecl = ?; | |||
let arguments = !con( | |||
!foldl(inputs, params, args, param, !con(args, param.fields)), | |||
extraArguments); | |||
list<MgbParamBase> dnnParams = params; | |||
} | |||
class MgbHashableOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | |||
MgbOp<mnemonic, params, traits>, MgbHashableOpMixin; | |||
#endif // MGB_BASE |
@@ -0,0 +1,240 @@ | |||
/** | |||
* \file src/core/include/megbrain/ir/ops.td | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#ifndef MGB_OPS | |||
#define MGB_OPS | |||
include "base.td" | |||
include "param_defs.td" | |||
include "mlir/Interfaces/SideEffectInterfaces.td" | |||
// Variadic element-wise op; the concrete mode is carried by ElemwiseParam.
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType);
}
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
// Dtype cast; carries both the input ($idtype) and output ($dtype) element
// types as attributes.
def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let inputs = (ins AnyType:$inputs);
let extraArguments = (ins
TypeAttr:$idtype,
MgbDTypeAttr:$dtype
);
let results = (outs AnyType);
}
// -- linear algebra
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>;
// NOTE(review): def name is BatchedMatrixMul but the mnemonic is
// "BatchedMatmul" — presumably intentional (mnemonic matches the megbrain
// op name); confirm before renaming either side.
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>;
def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
def SVD: MgbHashableOp<"SVD", [SVDParam]>;
// -- convolution / pooling
def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>;
def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;
def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;
// Fused conv + bias; the output dtype is an explicit attribute.
def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
}
def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
}
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;
def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
// Copies a tensor onto the given computing node.
def Copy: MgbHashableOp<"Copy"> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}
def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>;
def Argmax : MgbHashableOp<"Argmax", [AxisParam]>;
def Argmin : MgbHashableOp<"Argmin", [AxisParam]>;
def CondTake : MgbHashableOp<"CondTake">;
def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
// Custom hash/cmp: the hash covers only the op type and cmp always returns
// true, so any two UniformRNG ops compare equal regardless of param state.
def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
let cmpFunction = [{return true;}];
}
// Custom hash/cmp include only mean/std; other param state is ignored.
def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
}];
let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
}
// Tensor-creation ops pinned to a specific computing node.
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}
def Eye: MgbHashableOp<"Eye", [EyeParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}
def GetVarShape : MgbHashableOp<"GetVarShape">;
def Concat: MgbHashableOp<"Concat", [AxisParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;
def Identity: MgbHashableOp<"Identity">;
// Collective communication; group topology (key/nr_devices/rank/is_root)
// and the rendezvous endpoint (addr:port) are carried as attributes.
def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
let extraArguments = (ins
MgbStringAttr:$key,
MgbUI32Attr:$nr_devices,
MgbUI32Attr:$rank,
MgbBoolAttr:$is_root,
MgbBoolAttr:$local_grad,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbDTypeAttr:$dtype,
MgbStringAttr:$backend,
MgbStringAttr:$comp_node
);
}
// Point-to-point send, identified by $key, to peer $rank_to via addr:port.
def RemoteSend : MgbHashableOp<"RemoteSend"> {
let extraArguments = (ins
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbUI32Attr:$rank_to
);
}
// Point-to-point receive; expected shape/dtype and destination comp node
// are attributes since no tensor input carries that information.
def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
let extraArguments = (ins
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape,
MgbDTypeAttr:$dtype
);
}
// Non-maximum suppression keep op with IoU threshold and output cap.
def NMSKeep : MgbHashableOp<"NMSKeep"> {
let extraArguments = (ins
MgbF32Attr:$iou_thresh,
MgbUI32Attr:$max_output
);
}
// Splits a flat packed parameter tensor into pieces; $offsets/$shapes
// describe the layout of each piece.
def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$offsets,
// NOTE(review): `MgbSizeTAddr` — possible typo for `MgbSizeTAttr`; the
// declaration lives in base.td outside this view, so confirm it exists.
MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes
);
}
// Inverse of ParamPackSplit: concatenates pieces at the given offsets.
def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$offsets
);
}
// Axis permutation according to $pattern.
// NOTE(review): uses AnyMemRef rather than AnyType, unlike the other ops —
// presumably for the MLIR lowering path; confirm.
def Dimshuffle: MgbHashableOp<"Dimshuffle"> {
let inputs = (ins AnyMemRef:$input);
let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
let results = (outs AnyMemRef);
}
def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;
// TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
// Inserts size-1 axes at the given positions.
def AddAxis: MgbHashableOp<"AddAxis"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$axis
);
}
// Removes the given (size-1) axes.
def RemoveAxis: MgbHashableOp<"RemoveAxis"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$axis
);
}
// Shared base for subtensor / fancy-indexing ops. Each entry of $items is a
// 5-tuple of (i8, bool x4) describing one indexing item.
// NOTE(review): the tuple presumably encodes (axis, has_begin, has_end,
// has_step, has_idx) — confirm against megbrain's indexing descriptor.
class FancyIndexingBase<string name>: MgbHashableOp<name> {
let extraArguments = (ins
MgbArrayAttr<MgbTupleAttr<
[MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>>:$items
);
}
def Subtensor: FancyIndexingBase<"Subtensor">;
def SetSubtensor: FancyIndexingBase<"SetSubtensor">;
def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">;
def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">;
def IndexingSetMultiAxisVec: FancyIndexingBase<"IndexingSetMultiAxisVec">;
def IndexingIncrMultiAxisVec: FancyIndexingBase<"IndexingIncrMultiAxisVec">;
def MeshIndexing: FancyIndexingBase<"MeshIndexing">;
def IncrMeshIndexing: FancyIndexingBase<"IncrMeshIndexing">;
def SetMeshIndexing: FancyIndexingBase<"SetMeshIndexing">;
def BatchedMeshIndexing: FancyIndexingBase<"BatchedMeshIndexing">;
def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
// Element-wise op with mixed input/output dtypes; output dtype is explicit.
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
}
#endif // MGB_OPS |
@@ -47,3 +47,4 @@ pushd MegRay/third_party >/dev/null | |||
popd >/dev/null | |||
git submodule update --init pybind11 | |||
git submodule update --init llvm-project |