Browse Source

feat(imperative): auto generated opdef header and python binding

GitOrigin-RevId: d2f22ad5fe
release-1.2
Megvii Engine Team 4 years ago
parent
commit
69e3e32240
67 changed files with 2191 additions and 2621 deletions
  1. +6
    -8
      CMakeLists.txt
  2. +150
    -0
      dnn/scripts/gen_tablegen.py
  3. +3
    -30
      imperative/CMakeLists.txt
  4. +10
    -23
      imperative/python/megengine/core/autodiff/builtin_op_utils.py
  5. +0
    -8
      imperative/python/megengine/core/ops/_internal/__init__.py
  6. +0
    -10
      imperative/python/megengine/core/ops/_internal/all_ops.py
  7. +0
    -939
      imperative/python/megengine/core/ops/_internal/enum36.py
  8. +0
    -94
      imperative/python/megengine/core/ops/_internal/helper.py
  9. +0
    -194
      imperative/python/megengine/core/ops/_internal/misc_ops.py
  10. +1
    -14
      imperative/python/megengine/core/ops/builtin/__init__.py
  11. +2
    -5
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  12. +1
    -1
      imperative/python/megengine/core/tensor/utils.py
  13. +24
    -22
      imperative/python/megengine/distributed/functional.py
  14. +0
    -20
      imperative/python/megengine/functional/elemwise.py
  15. +7
    -9
      imperative/python/megengine/functional/math.py
  16. +10
    -19
      imperative/python/megengine/functional/nn.py
  17. +6
    -6
      imperative/python/megengine/functional/quantized.py
  18. +2
    -9
      imperative/python/megengine/functional/tensor.py
  19. +3
    -9
      imperative/python/megengine/jit/tracing.py
  20. +6
    -17
      imperative/python/megengine/module/conv.py
  21. +0
    -1
      imperative/python/megengine/module/elemwise.py
  22. +2
    -2
      imperative/python/megengine/module/qat/conv.py
  23. +1
    -4
      imperative/python/megengine/module/quantized/elemwise.py
  24. +2
    -2
      imperative/python/megengine/utils/profiler.py
  25. +1
    -3
      imperative/python/src/graph_rt.cpp
  26. +0
    -3
      imperative/python/src/imperative_rt.cpp
  27. +13
    -179
      imperative/python/src/ops.cpp
  28. +2
    -2
      imperative/python/test/unit/core/test_dtype_quant.py
  29. +17
    -18
      imperative/python/test/unit/core/test_indexing_op.py
  30. +0
    -320
      imperative/python/tools/gen_ops.py
  31. +0
    -40
      imperative/python/tools/ops.tpl.py
  32. +9
    -1
      imperative/src/impl/op_def.cpp
  33. +1
    -1
      imperative/src/impl/op_trait.cpp
  34. +42
    -1
      imperative/src/impl/op_trait.h
  35. +46
    -0
      imperative/src/impl/ops/autogen.cpp
  36. +6
    -11
      imperative/src/impl/ops/batch_norm.cpp
  37. +3
    -3
      imperative/src/impl/ops/broadcast.cpp
  38. +3
    -32
      imperative/src/impl/ops/collective_comm.cpp
  39. +1
    -4
      imperative/src/impl/ops/cond_take.cpp
  40. +4
    -4
      imperative/src/impl/ops/elemwise.cpp
  41. +1
    -41
      imperative/src/impl/ops/io_remote.cpp
  42. +1
    -3
      imperative/src/impl/ops/nms.cpp
  43. +630
    -0
      imperative/src/impl/ops/specializations.cpp
  44. +1
    -5
      imperative/src/impl/ops/tensor_manip.cpp
  45. +7
    -7
      imperative/src/impl/profiler.cpp
  46. +3
    -3
      imperative/src/impl/proxy_graph.cpp
  47. +7
    -16
      imperative/src/include/megbrain/imperative/op_def.h
  48. +8
    -15
      imperative/src/include/megbrain/imperative/ops/autogen.h
  49. +0
    -70
      imperative/src/include/megbrain/imperative/ops/batch_norm.h
  50. +0
    -35
      imperative/src/include/megbrain/imperative/ops/broadcast.h
  51. +0
    -69
      imperative/src/include/megbrain/imperative/ops/collective_comm.h
  52. +0
    -42
      imperative/src/include/megbrain/imperative/ops/elemwise.h
  53. +0
    -77
      imperative/src/include/megbrain/imperative/ops/io_remote.h
  54. +0
    -41
      imperative/src/include/megbrain/imperative/ops/nms.h
  55. +0
    -99
      imperative/src/include/megbrain/imperative/ops/tensor_manip.h
  56. +8
    -8
      imperative/src/test/backward_graph.cpp
  57. +6
    -5
      imperative/src/test/collective_comm.cpp
  58. +1
    -1
      imperative/src/test/cond_take.cpp
  59. +1
    -1
      imperative/src/test/helper.cpp
  60. +9
    -14
      imperative/src/test/io_remote.cpp
  61. +14
    -0
      imperative/tablegen/CMakeLists.txt
  62. +383
    -0
      imperative/tablegen/autogen.cpp
  63. +228
    -0
      imperative/tablegen/helper.h
  64. +1
    -1
      imperative/test/CMakeLists.txt
  65. +257
    -0
      src/core/include/megbrain/ir/base.td
  66. +240
    -0
      src/core/include/megbrain/ir/ops.td
  67. +1
    -0
      third_party/prepare.sh

+ 6
- 8
CMakeLists.txt View File

@@ -230,6 +230,10 @@ endif()
# FIXME At present, there are some conflicts between the LLVM that halide
# depends on and the LLVM that MLIR depends on. Should be fixed in subsequent
# versions.
if(MGE_BUILD_IMPERATIVE_RT)
set(MGE_WITH_HALIDE OFF)
message(WARNING "cannot use HALIDE when building IMPERATIVE_RT")
endif()
if(MGE_WITH_JIT_MLIR)
if(MGE_WITH_HALIDE)
message(FATAL_ERROR "please set MGE_WITH_HALIDE to OFF with MGE_WITH_JIT_MLIR enabled")
@@ -310,7 +314,7 @@ if(MGE_INFERENCE_ONLY)
set(MGE_BUILD_IMPERATIVE_RT OFF)
endif()

if(MGE_WITH_JIT_MLIR)
if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
include(cmake/llvm-project.cmake)
endif()

@@ -750,7 +754,7 @@ target_include_directories(mgb_opr_param_defs
add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs)
install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})

if(MGE_WITH_JIT_MLIR)
if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
# generate param_defs.td
set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles)
set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir)
@@ -800,12 +804,6 @@ if(TARGET _imperative_rt)
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}>
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}>
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py
DEPENDS _imperative_rt
VERBATIM
)


+ 150
- 0
dnn/scripts/gen_tablegen.py View File

@@ -0,0 +1,150 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io

from gen_param_defs import member_defs, ParamDef, IndentWriterBase


class ConverterWriter(IndentWriterBase):
_skip_current_param = False
_last_param = None
_current_tparams = None
_packed = None
_const = None

def __call__(self, fout, defs):
super().__call__(fout)
self._write("// %s", self._get_header())
self._write("#ifndef MGB_PARAM")
self._write("#define MGB_PARAM")
self._process(defs)
self._write("#endif // MGB_PARAM")

def _ctype2attr(self, ctype, value):
if ctype == 'uint32_t':
return 'MgbUI32Attr', value
if ctype == 'uint64_t':
return 'MgbUI64Attr', value
if ctype == 'int32_t':
return 'MgbI32Attr', value
if ctype == 'float':
return 'MgbF32Attr', value
if ctype == 'double':
return 'MgbF64Attr', value
if ctype == 'bool':
return 'MgbBoolAttr', value
if ctype == 'DTypeEnum':
self._packed = False
return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
raise RuntimeError("unknown ctype")

def _on_param_begin(self, p):
self._last_param = p
if p.is_legacy:
self._skip_current_param = True
return
self._packed = True
self._current_tparams = []
self._const = set()

def _on_param_end(self, p):
if self._skip_current_param:
self._skip_current_param = False
return
if self._packed:
self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
else:
self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
self._write("let fields = (ins", indent=1)
self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
self._write(");", indent=-1)
self._write("}\n", indent=-1)
if self._packed:
self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
self._current_tparams = None
self._packed = None
self._const = None

def _wrapped_with_default_value(self, attr, default):
return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)

def _on_member_enum(self, e):
p = self._last_param

# Note: always generate llvm Record def for enum attribute even it was not
# directly used by any operator, or other enum couldn't alias to this enum
td_class = "{}{}".format(p.name, e.name)
fullname = "::megdnn::param::{}".format(p.name)
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
def format(v):
return '\"{}\"'.format(str(v))
enum_def += ','.join(format(i) for i in e.members)
enum_def += "]>"
self._write("def {} : {};".format(td_class, enum_def))

if self._skip_current_param:
return

# wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default)
wrapped = self._wrapped_with_default_value(td_class, default_val)

self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

def _on_member_enum_alias(self, e):
p = self._last_param
if self._skip_current_param:
return

# write enum attr def
td_class = "{}{}".format(p.name, e.name)
fullname = "::megdnn::param::{}".format(p.name)
base_td_class = "{}{}".format(e.src_class, e.src_name)
enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
self._write("def {} : {};".format(td_class, enum_def))

# wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default())
wrapped = self._wrapped_with_default_value(td_class, default_val)

self._current_tparams.append("{}:${}".format(wrapped, e.name_field))


def _on_member_field(self, f):
if self._skip_current_param:
return
attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
if str(value) in self._const:
value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
wrapped = self._wrapped_with_default_value(attr, value)
self._current_tparams.append("{}:${}".format(wrapped, f.name))

def _on_const_field(self, f):
self._const.add(str(f.name))

def main():
parser = argparse.ArgumentParser('generate op param tablegen file')
parser.add_argument('input')
parser.add_argument('output')
args = parser.parse_args()

with open(args.input) as fin:
inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
input_hash = hashlib.sha256()
input_hash.update(inputs.encode(encoding='UTF-8'))
input_hash = input_hash.hexdigest()

writer = ConverterWriter()
with open(args.output, 'w') as fout:
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)

if __name__ == "__main__":
main()

+ 3
- 30
imperative/CMakeLists.txt View File

@@ -8,9 +8,7 @@ file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/sr

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1")

file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl")
file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py)
list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py)
file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h
${PROJECT_SOURCE_DIR}/src/core/include/*
${PROJECT_SOURCE_DIR}/src/opr/include/*
@@ -19,33 +17,8 @@ file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h
${PROJECT_SOURCE_DIR}/dnn/include/*)

set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/)
set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal)
file(MAKE_DIRECTORY ${GEN_OPS_DIR})
set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py)
set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py)
set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py)

##################### generate python opr_param_defs.py ##############

file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS})

add_custom_command(
OUTPUT ${GEN_OPS_FILE}
COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME}
COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE}
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE}
DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE}
VERBATIM
)

add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE})

##################### end of opdef generation #########################
add_subdirectory(tablegen)

add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT})

@@ -73,7 +46,7 @@ else()
endif()
endif()

target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR})
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS)
@@ -87,7 +60,7 @@ if (APPLE OR MSVC OR WIN32)
message(VERBOSE "overwriting SUFFIX at macos and windows before config by set_target_properties")
pybind11_extension(${MODULE_NAME})
endif()
add_dependencies(${MODULE_NAME} gen_opr_py _version_ld)
add_dependencies(${MODULE_NAME} mgb_opdef _version_ld)

if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
add_subdirectory(test)


+ 10
- 23
imperative/python/megengine/core/autodiff/builtin_op_utils.py View File

@@ -19,7 +19,6 @@ from ..ops.builtin import (
IndexingMultiAxisVec,
IndexingSetMultiAxisVec,
OpDef,
OprAttr,
Reduce,
Reshape,
SetSubtensor,
@@ -31,8 +30,6 @@ from ..tensor.function import Function
from ..tensor.tensor import Tensor
from ..tensor.tensor_wrapper import TensorWrapper

_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]


@functools.singledispatch
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
@@ -41,17 +38,18 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):

@builtin_op_get_backward_fn.register(OpDef)
def _(op: OpDef, inputs, outputs, input_requires_grad):
if isinstance(op, OprAttr):
grad_fn = _oprAttr_grad_fn.get(op.type, None)
if grad_fn is None:
if op.type == Reduce.name and op.param[0] == _reduce_sum_param:
grad_fn = reduce_sum_grad_fn
else:
grad_fn = default_grad_fn
if isinstance(op, Reshape):
grad_fn = reshape_grad_fn
elif isinstance(op, Subtensor):
grad_fn = subtensor_grad_fn
elif isinstance(op, IndexingMultiAxisVec):
grad_fn = indexingMultiAxisVec_grad_fn
elif isinstance(op, Broadcast) or (
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
):
grad_fn = elemwise_add_grad_fn
elif isinstance(op, Reduce) and op.mode.name == "SUM":
grad_fn = reduce_sum_grad_fn
else:
grad_fn = default_grad_fn
return grad_fn(op, inputs, outputs, input_requires_grad)
@@ -152,9 +150,7 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad):

# override for Subtensor
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
grad_op = OprAttr()
grad_op.type = SetSubtensor.name
grad_op.param = op.param
grad_op = SetSubtensor(op.items)

input_shape = get_shape(inputs[0])
params = inputs[1:]
@@ -175,9 +171,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):

# override for IndexingMultiAxisVec
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
grad_op = OprAttr()
grad_op.type = IndexingSetMultiAxisVec.name
grad_op.param = op.param
grad_op = IndexingSetMultiAxisVec(op.items)

input_shape = get_shape(inputs[0])
params = inputs[1:]
@@ -209,10 +203,3 @@ def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)

return backward, [True]


_oprAttr_grad_fn = {
Reshape.name: reshape_grad_fn,
Subtensor.name: subtensor_grad_fn,
IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
}

+ 0
- 8
imperative/python/megengine/core/ops/_internal/__init__.py View File

@@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

+ 0
- 10
imperative/python/megengine/core/ops/_internal/all_ops.py View File

@@ -1,10 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .generated_ops import *
from .misc_ops import *

+ 0
- 939
imperative/python/megengine/core/ops/_internal/enum36.py View File

@@ -1,939 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import sys
from functools import reduce
from operator import or_ as _or_
from types import DynamicClassAttribute, MappingProxyType

# try _collections first to reduce startup cost
try:
from _collections import OrderedDict
except ImportError:
from collections import OrderedDict


__all__ = [
"EnumMeta",
"Enum",
"IntEnum",
"Flag",
"IntFlag",
"auto",
"unique",
]


def _is_descriptor(obj):
"""Returns True if obj is a descriptor, False otherwise."""
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)


def _is_dunder(name):
"""Returns True if a __dunder__ name, False otherwise."""
return (
name[:2] == name[-2:] == "__"
and name[2:3] != "_"
and name[-3:-2] != "_"
and len(name) > 4
)


def _is_sunder(name):
"""Returns True if a _sunder_ name, False otherwise."""
return (
name[0] == name[-1] == "_"
and name[1:2] != "_"
and name[-2:-1] != "_"
and len(name) > 2
)


def _make_class_unpicklable(cls):
"""Make the given class un-picklable."""

def _break_on_call_reduce(self, proto):
raise TypeError("%r cannot be pickled" % self)

cls.__reduce_ex__ = _break_on_call_reduce
cls.__module__ = "<unknown>"


_auto_null = object()


class auto:
"""
Instances are replaced with an appropriate value in Enum class suites.
"""

value = _auto_null


class _EnumDict(dict):
"""
Track enum member order and ensure member names are not reused.

EnumMeta will use the names found in self._member_names as the
enumeration member names.

"""

def __init__(self):
super().__init__()
self._member_names = []
self._last_values = []

def __setitem__(self, key, value):
"""
Changes anything not dundered or not a descriptor.

If an enum member name is used twice, an error is raised; duplicate
values are not checked for.

Single underscore (sunder) names are reserved.

"""
if _is_sunder(key):
if key not in (
"_order_",
"_create_pseudo_member_",
"_generate_next_value_",
"_missing_",
):
raise ValueError("_names_ are reserved for future Enum use")
if key == "_generate_next_value_":
setattr(self, "_generate_next_value", value)
elif _is_dunder(key):
if key == "__order__":
key = "_order_"
elif key in self._member_names:
# descriptor overwriting an enum?
raise TypeError("Attempted to reuse key: %r" % key)
elif not _is_descriptor(value):
if key in self:
# enum overwriting a descriptor?
raise TypeError("%r already defined as: %r" % (key, self[key]))
if isinstance(value, auto):
if value.value == _auto_null:
value.value = self._generate_next_value(
key, 1, len(self._member_names), self._last_values[:]
)
value = value.value
self._member_names.append(key)
self._last_values.append(value)
super().__setitem__(key, value)


# Dummy value for Enum as EnumMeta explicitly checks for it, but of course
# until EnumMeta finishes running the first time the Enum class doesn't exist.
# This is also why there are checks in EnumMeta like `if Enum is not None`
Enum = None


class EnumMeta(type):
"""Metaclass for Enum"""

@classmethod
def __prepare__(metacls, cls, bases):
# create the namespace dict
enum_dict = _EnumDict()
# inherit previous flags and _generate_next_value_ function
member_type, first_enum = metacls._get_mixins_(bases)
if first_enum is not None:
enum_dict["_generate_next_value_"] = getattr(
first_enum, "_generate_next_value_", None
)
return enum_dict

def __new__(metacls, cls, bases, classdict):
# an Enum class is final once enumeration items have been defined; it
# cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting
# class will fail).
member_type, first_enum = metacls._get_mixins_(bases)
__new__, save_new, use_args = metacls._find_new_(
classdict, member_type, first_enum
)

# save enum items into separate mapping so they don't get baked into
# the new class
enum_members = {k: classdict[k] for k in classdict._member_names}
for name in classdict._member_names:
del classdict[name]

# adjust the sunders
_order_ = classdict.pop("_order_", None)

# check for illegal enum names (any others?)
invalid_names = set(enum_members) & {
"mro",
}
if invalid_names:
raise ValueError(
"Invalid enum member name: {0}".format(",".join(invalid_names))
)

# create a default docstring if one has not been provided
if "__doc__" not in classdict:
classdict["__doc__"] = "An enumeration."

# create our new Enum type
enum_class = super().__new__(metacls, cls, bases, classdict)
enum_class._member_names_ = [] # names in definition order
enum_class._member_map_ = OrderedDict() # name->value map
enum_class._member_type_ = member_type

# save attributes from super classes so we know if we can take
# the shortcut of storing members in the class dict
base_attributes = {a for b in enum_class.mro() for a in b.__dict__}

# Reverse value->name map for hashable values.
enum_class._value2member_map_ = {}

# If a custom type is mixed into the Enum, and it does not know how
# to pickle itself, pickle.dumps will succeed but pickle.loads will
# fail. Rather than have the error show up later and possibly far
# from the source, sabotage the pickle protocol for this class so
# that pickle.dumps also fails.
#
# However, if the new class implements its own __reduce_ex__, do not
# sabotage -- it's on them to make sure it works correctly. We use
# __reduce_ex__ instead of any of the others as it is preferred by
# pickle over __reduce__, and it handles all pickle protocols.
if "__reduce_ex__" not in classdict:
if member_type is not object:
methods = (
"__getnewargs_ex__",
"__getnewargs__",
"__reduce_ex__",
"__reduce__",
)
if not any(m in member_type.__dict__ for m in methods):
_make_class_unpicklable(enum_class)

# instantiate them, checking for duplicates as we go
# we instantiate first instead of checking for duplicates first in case
# a custom __new__ is doing something funky with the values -- such as
# auto-numbering ;)
for member_name in classdict._member_names:
value = enum_members[member_name]
if not isinstance(value, tuple):
args = (value,)
else:
args = value
if member_type is tuple: # special case for tuple enums
args = (args,) # wrap it one more time
if not use_args:
enum_member = __new__(enum_class)
if not hasattr(enum_member, "_value_"):
enum_member._value_ = value
else:
enum_member = __new__(enum_class, *args)
if not hasattr(enum_member, "_value_"):
if member_type is object:
enum_member._value_ = value
else:
enum_member._value_ = member_type(*args)
value = enum_member._value_
enum_member._name_ = member_name
enum_member.__objclass__ = enum_class
enum_member.__init__(*args)
# If another member with the same value was already defined, the
# new member becomes an alias to the existing one.
for name, canonical_member in enum_class._member_map_.items():
if canonical_member._value_ == enum_member._value_:
enum_member = canonical_member
break
else:
# Aliases don't appear in member names (only in __members__).
enum_class._member_names_.append(member_name)
# performance boost for any member that would not shadow
# a DynamicClassAttribute
if member_name not in base_attributes:
setattr(enum_class, member_name, enum_member)
# now add to _member_map_
enum_class._member_map_[member_name] = enum_member
try:
# This may fail if value is not hashable. We can't add the value
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_[value] = enum_member
except TypeError:
pass

# double check that repr and friends are not the mixin's or various
# things break (such as pickle)
for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"):
class_method = getattr(enum_class, name)
obj_method = getattr(member_type, name, None)
enum_method = getattr(first_enum, name, None)
if obj_method is not None and obj_method is class_method:
setattr(enum_class, name, enum_method)

# replace any other __new__ with our own (as long as Enum is not None,
# anyway) -- again, this is to support pickle
if Enum is not None:
# if the user defined their own __new__, save it before it gets
# clobbered in case they subclass later
if save_new:
enum_class.__new_member__ = __new__
enum_class.__new__ = Enum.__new__

# py3 support for definition order (helps keep py2/py3 code in sync)
if _order_ is not None:
if isinstance(_order_, str):
_order_ = _order_.replace(",", " ").split()
if _order_ != enum_class._member_names_:
raise TypeError("member order does not match _order_")

return enum_class

def __bool__(self):
"""
classes/types should always be True.
"""
return True

def __call__(
cls, value, names=None, *, module=None, qualname=None, type=None, start=1
):
"""
Either returns an existing member, or creates a new enum class.

This method is used both when an enum class is given a value to match
to an enumeration member (i.e. Color(3)) and for the functional API
(i.e. Color = Enum('Color', names='RED GREEN BLUE')).

When used for the functional API:

`value` will be the name of the new class.

`names` should be either a string of white-space/comma delimited names
(values will start at `start`), or an iterator/mapping of name, value pairs.

`module` should be set to the module this class is being created in;
if it is not set, an attempt to find that module will be made, but if
it fails the class will not be picklable.

`qualname` should be set to the actual location this class can be found
at in its module; by default it is set to the global scope. If this is
not correct, unpickling will fail in some circumstances.

`type`, if set, will be mixed in as the first base class.

"""
if names is None: # simple value lookup
return cls.__new__(cls, value)
# otherwise, functional API: we're creating a new Enum type
return cls._create_(
value, names, module=module, qualname=qualname, type=type, start=start
)

def __contains__(cls, member):
return isinstance(member, cls) and member._name_ in cls._member_map_

def __delattr__(cls, attr):
# nicer error message when someone tries to delete an attribute
# (see issue19025).
if attr in cls._member_map_:
raise AttributeError("%s: cannot delete Enum member." % cls.__name__)
super().__delattr__(attr)

def __dir__(self):
return [
"__class__",
"__doc__",
"__members__",
"__module__",
] + self._member_names_

def __getattr__(cls, name):
"""
Return the enum member matching `name`

We use __getattr__ instead of descriptors or inserting into the enum
class' __dict__ in order to support `name` and `value` being both
properties for enum members (which live in the class' __dict__) and
enum members themselves.

"""
if _is_dunder(name):
raise AttributeError(name)
try:
return cls._member_map_[name]
except KeyError:
raise AttributeError(name) from None

def __getitem__(cls, name):
return cls._member_map_[name]

def __iter__(cls):
return (cls._member_map_[name] for name in cls._member_names_)

def __len__(cls):
return len(cls._member_names_)

@property
def __members__(cls):
"""
Returns a mapping of member name->value.

This mapping lists all enum members, including aliases. Note that this
is a read-only view of the internal mapping.

"""
return MappingProxyType(cls._member_map_)

def __repr__(cls):
return "<enum %r>" % cls.__name__

def __reversed__(cls):
return (cls._member_map_[name] for name in reversed(cls._member_names_))

def __setattr__(cls, name, value):
"""
Block attempts to reassign Enum members.

A simple assignment to the class namespace only changes one of the
several possible ways to get an Enum member from the Enum class,
resulting in an inconsistent Enumeration.

"""
member_map = cls.__dict__.get("_member_map_", {})
if name in member_map:
raise AttributeError("Cannot reassign members.")
super().__setattr__(name, value)

def _create_(
cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1
):
"""
Convenience method to create a new Enum class.

`names` can be:

* A string containing member names, separated either with spaces or
commas. Values are incremented by 1 from `start`.
* An iterable of member names. Values are incremented by 1 from `start`.
* An iterable of (member name, value) pairs.
* A mapping of member name -> value pairs.

"""
metacls = cls.__class__
bases = (cls,) if type is None else (type, cls)
_, first_enum = cls._get_mixins_(bases)
classdict = metacls.__prepare__(class_name, bases)

# special processing needed for names?
if isinstance(names, str):
names = names.replace(",", " ").split()
if isinstance(names, (tuple, list)) and names and isinstance(names[0], str):
original_names, names = names, []
last_values = []
for count, name in enumerate(original_names):
value = first_enum._generate_next_value_(
name, start, count, last_values[:]
)
last_values.append(value)
names.append((name, value))

# Here, names is either an iterable of (name, value) or a mapping.
for item in names:
if isinstance(item, str):
member_name, member_value = item, names[item]
else:
member_name, member_value = item
classdict[member_name] = member_value
enum_class = metacls.__new__(metacls, class_name, bases, classdict)

# TODO: replace the frame hack if a blessed way to know the calling
# module is ever developed
if module is None:
try:
module = sys._getframe(2).f_globals["__name__"]
except (AttributeError, ValueError) as exc:
pass
if module is None:
_make_class_unpicklable(enum_class)
else:
enum_class.__module__ = module
if qualname is not None:
enum_class.__qualname__ = qualname

return enum_class

@staticmethod
def _get_mixins_(bases):
"""
Returns the type for creating enum members, and the first inherited
enum class.

bases: the tuple of bases that was given to __new__

"""
if not bases:
return object, Enum

# double check that we are not subclassing a class with existing
# enumeration members; while we're at it, see if any other data
# type has been mixed in so we can use the correct __new__
member_type = first_enum = None
for base in bases:
if base is not Enum and issubclass(base, Enum) and base._member_names_:
raise TypeError("Cannot extend enumerations")
# base is now the last base in bases
if not issubclass(base, Enum):
raise TypeError(
"new enumerations must be created as "
"`ClassName([mixin_type,] enum_type)`"
)

# get correct mix-in type (either mix-in type of Enum subclass, or
# first base if last base is Enum)
if not issubclass(bases[0], Enum):
member_type = bases[0] # first data type
first_enum = bases[-1] # enum type
else:
for base in bases[0].__mro__:
# most common: (IntEnum, int, Enum, object)
# possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>,
# <class 'int'>, <Enum 'Enum'>,
# <class 'object'>)
if issubclass(base, Enum):
if first_enum is None:
first_enum = base
else:
if member_type is None:
member_type = base

return member_type, first_enum

@staticmethod
def _find_new_(classdict, member_type, first_enum):
"""
Returns the __new__ to be used for creating the enum members.

classdict: the class dictionary given to __new__
member_type: the data type whose __new__ will be used by default
first_enum: enumeration to check for an overriding __new__

"""
# now find the correct __new__, checking to see of one was defined
# by the user; also check earlier enum classes in case a __new__ was
# saved as __new_member__
__new__ = classdict.get("__new__", None)

# should __new__ be saved as __new_member__ later?
save_new = __new__ is not None

if __new__ is None:
# check all possibles for __new_member__ before falling back to
# __new__
for method in ("__new_member__", "__new__"):
for possible in (member_type, first_enum):
target = getattr(possible, method, None)
if target not in {
None,
None.__new__,
object.__new__,
Enum.__new__,
}:
__new__ = target
break
if __new__ is not None:
break
else:
__new__ = object.__new__

# if a non-object.__new__ is used then whatever value/tuple was
# assigned to the enum member name will be passed to __new__ and to the
# new enum member's __init__
if __new__ is object.__new__:
use_args = False
else:
use_args = True

return __new__, save_new, use_args


class Enum(metaclass=EnumMeta):
"""
Generic enumeration.

Derive from this class to define new enumerations.

"""

def __new__(cls, value):
# all enum instances are actually created during class construction
# without calling this method; this method is called by the metaclass'
# __call__ (i.e. Color(3) ), and by pickle
if type(value) is cls:
# For lookups like Color(Color.RED)
return value
# by-value search for a matching enum member
# see if it's in the reverse mapping (for hashable values)
try:
if value in cls._value2member_map_:
return cls._value2member_map_[value]
except TypeError:
# not there, now do long search -- O(n) behavior
for member in cls._member_map_.values():
if member._value_ == value:
return member
# still not found -- try _missing_ hook
return cls._missing_(value)

def _generate_next_value_(name, start, count, last_values):
for last_value in reversed(last_values):
try:
return last_value + 1
except TypeError:
pass
else:
return start

@classmethod
def _missing_(cls, value):
raise ValueError("%r is not a valid %s" % (value, cls.__name__))

def __repr__(self):
return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_)

def __str__(self):
return "%s.%s" % (self.__class__.__name__, self._name_)

def __dir__(self):
added_behavior = [
m
for cls in self.__class__.mro()
for m in cls.__dict__
if m[0] != "_" and m not in self._member_map_
]
return ["__class__", "__doc__", "__module__"] + added_behavior

def __format__(self, format_spec):
# mixed-in Enums should use the mixed-in type's __format__, otherwise
# we can get strange results with the Enum name showing up instead of
# the value

# pure Enum branch
if self._member_type_ is object:
cls = str
val = str(self)
# mix-in branch
else:
cls = self._member_type_
val = self._value_
return cls.__format__(val, format_spec)

def __hash__(self):
return hash(self._name_)

def __reduce_ex__(self, proto):
return self.__class__, (self._value_,)

# DynamicClassAttribute is used to provide access to the `name` and
# `value` properties of enum members while keeping some measure of
# protection from modification, while still allowing for an enumeration
# to have members named `name` and `value`. This works because enumeration
# members are not set directly on the enum class -- __getattr__ is
# used to look them up.

@DynamicClassAttribute
def name(self):
"""The name of the Enum member."""
return self._name_

@DynamicClassAttribute
def value(self):
"""The value of the Enum member."""
return self._value_

@classmethod
def _convert(cls, name, module, filter, source=None):
"""
Create a new Enum subclass that replaces a collection of global constants
"""
# convert all constants from source (or module) that pass filter() to
# a new Enum called name, and export the enum and its members back to
# module;
# also, replace the __reduce_ex__ method so unpickling works in
# previous Python versions
module_globals = vars(sys.modules[module])
if source:
source = vars(source)
else:
source = module_globals
# We use an OrderedDict of sorted source keys so that the
# _value2member_map is populated in the same order every time
# for a consistent reverse mapping of number to name when there
# are multiple names for the same number rather than varying
# between runs due to hash randomization of the module dictionary.
members = [(name, source[name]) for name in source.keys() if filter(name)]
try:
# sort by value
members.sort(key=lambda t: (t[1], t[0]))
except TypeError:
# unless some values aren't comparable, in which case sort by name
members.sort(key=lambda t: t[0])
cls = cls(name, members, module=module)
cls.__reduce_ex__ = _reduce_ex_by_name
module_globals.update(cls.__members__)
module_globals[name] = cls
return cls


class IntEnum(int, Enum):
"""Enum where members are also (and must be) ints"""


def _reduce_ex_by_name(self, proto):
return self.name


class Flag(Enum):
"""Support for flags"""

def _generate_next_value_(name, start, count, last_values):
"""
Generate the next value when not given.

name: the name of the member
start: the initital start value or None
count: the number of existing members
last_value: the last value assigned or None
"""
if not count:
return start if start is not None else 1
for last_value in reversed(last_values):
try:
high_bit = _high_bit(last_value)
break
except Exception:
raise TypeError("Invalid Flag value: %r" % last_value) from None
return 2 ** (high_bit + 1)

@classmethod
def _missing_(cls, value):
original_value = value
if value < 0:
value = ~value
possible_member = cls._create_pseudo_member_(value)
if original_value < 0:
possible_member = ~possible_member
return possible_member

@classmethod
def _create_pseudo_member_(cls, value):
"""
Create a composite member iff value contains only members.
"""
pseudo_member = cls._value2member_map_.get(value, None)
if pseudo_member is None:
# verify all bits are accounted for
_, extra_flags = _decompose(cls, value)
if extra_flags:
raise ValueError("%r is not a valid %s" % (value, cls.__name__))
# construct a singleton enum pseudo-member
pseudo_member = object.__new__(cls)
pseudo_member._name_ = None
pseudo_member._value_ = value
# use setdefault in case another thread already created a composite
# with this value
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
return pseudo_member

def __contains__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return other._value_ & self._value_ == other._value_

def __repr__(self):
cls = self.__class__
if self._name_ is not None:
return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_)
members, uncovered = _decompose(cls, self._value_)
return "<%s.%s: %r>" % (
cls.__name__,
"|".join([str(m._name_ or m._value_) for m in members]),
self._value_,
)

def __str__(self):
cls = self.__class__
if self._name_ is not None:
return "%s.%s" % (cls.__name__, self._name_)
members, uncovered = _decompose(cls, self._value_)
if len(members) == 1 and members[0]._name_ is None:
return "%s.%r" % (cls.__name__, members[0]._value_)
else:
return "%s.%s" % (
cls.__name__,
"|".join([str(m._name_ or m._value_) for m in members]),
)

def __bool__(self):
return bool(self._value_)

def __or__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.__class__(self._value_ | other._value_)

def __and__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.__class__(self._value_ & other._value_)

def __xor__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.__class__(self._value_ ^ other._value_)

def __invert__(self):
members, uncovered = _decompose(self.__class__, self._value_)
inverted_members = [
m
for m in self.__class__
if m not in members and not m._value_ & self._value_
]
inverted = reduce(_or_, inverted_members, self.__class__(0))
return self.__class__(inverted)


class IntFlag(int, Flag):
"""Support for integer-based Flags"""

@classmethod
def _missing_(cls, value):
if not isinstance(value, int):
raise ValueError("%r is not a valid %s" % (value, cls.__name__))
new_member = cls._create_pseudo_member_(value)
return new_member

@classmethod
def _create_pseudo_member_(cls, value):
pseudo_member = cls._value2member_map_.get(value, None)
if pseudo_member is None:
need_to_create = [value]
# get unaccounted for bits
_, extra_flags = _decompose(cls, value)
# timer = 10
while extra_flags:
# timer -= 1
bit = _high_bit(extra_flags)
flag_value = 2 ** bit
if (
flag_value not in cls._value2member_map_
and flag_value not in need_to_create
):
need_to_create.append(flag_value)
if extra_flags == -flag_value:
extra_flags = 0
else:
extra_flags ^= flag_value
for value in reversed(need_to_create):
# construct singleton pseudo-members
pseudo_member = int.__new__(cls, value)
pseudo_member._name_ = None
pseudo_member._value_ = value
# use setdefault in case another thread already created a composite
# with this value
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
return pseudo_member

def __or__(self, other):
if not isinstance(other, (self.__class__, int)):
return NotImplemented
result = self.__class__(self._value_ | self.__class__(other)._value_)
return result

def __and__(self, other):
if not isinstance(other, (self.__class__, int)):
return NotImplemented
return self.__class__(self._value_ & self.__class__(other)._value_)

def __xor__(self, other):
if not isinstance(other, (self.__class__, int)):
return NotImplemented
return self.__class__(self._value_ ^ self.__class__(other)._value_)

__ror__ = __or__
__rand__ = __and__
__rxor__ = __xor__

def __invert__(self):
result = self.__class__(~self._value_)
return result


def _high_bit(value):
"""returns index of highest bit, or -1 if value is zero or negative"""
return value.bit_length() - 1


def unique(enumeration):
"""Class decorator for enumerations ensuring unique member values."""
duplicates = []
for name, member in enumeration.__members__.items():
if name != member.name:
duplicates.append((name, member.name))
if duplicates:
alias_details = ", ".join(
["%s -> %s" % (alias, name) for (alias, name) in duplicates]
)
raise ValueError(
"duplicate values found in %r: %s" % (enumeration, alias_details)
)
return enumeration


def _decompose(flag, value):
"""Extract all members from the value."""
# _decompose is only called if the value is not named
not_covered = value
negative = value < 0
# issue29167: wrap accesses to _value2member_map_ in a list to avoid race
# conditions between iterating over it and having more psuedo-
# members added to it
if negative:
# only check for named flags
flags_to_check = [
(m, v)
for v, m in list(flag._value2member_map_.items())
if m.name is not None
]
else:
# check for named flags and powers-of-two flags
flags_to_check = [
(m, v)
for v, m in list(flag._value2member_map_.items())
if m.name is not None or _power_of_two(v)
]
members = []
for member, member_value in flags_to_check:
if member_value and member_value & value == member_value:
members.append(member)
not_covered &= ~member_value
if not members and value in flag._value2member_map_:
members.append(flag._value2member_map_[value])
members.sort(key=lambda m: m._value_, reverse=True)
if len(members) > 1 and members[0].value == value:
# we have the breakdown, don't need the value member itself
members.pop(0)
return members, not_covered


def _power_of_two(value):
if value < 1:
return False
return value == 2 ** _high_bit(value)

+ 0
- 94
imperative/python/megengine/core/ops/_internal/helper.py View File

@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings

from ..._imperative_rt.ops import OprAttr
from . import param_defs


def make_param(param, ptype, kwargs):
if param is not None:
if isinstance(param, ptype):
return param

param = [param]
assert len(param) == len(
ptype.__slots__
), "{} needs {} params, but {} are provided".format(
ptype, len(ptype.__slots__), len(param)
)
return ptype(*param)

ckw = {}
for i in ptype.__slots__:
val = kwargs.pop(i, ckw)
if val is not ckw:
ckw[i] = val
return ptype(**ckw)


class PodOpVisitor:
__name2subclass = {}
__c = None

name = None
param_names = []
config = None

def __init__(self, config, **params):
self.config = config
assert set(params) == set(self.param_names)
self.__dict__.update(params)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) # python 3.5 does not have this
name = cls.name
if name in cls.__name2subclass:
if not issubclass(cls, cls.__name2subclass[name]):
warnings.warn("Multiple subclasses for bultin op: %s" % name)
cls.__name2subclass[name] = cls

def to_c(self):
if self.__c:
return self.__c
op = OprAttr()
op.type = self.name
if self.config is not None:
op.config = self.config
# first 4 bytes is TAG, has to remove them currently
op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names)
self.__c = op
return op

def __eq__(self, rhs):
return self.to_c() == rhs.to_c()

def __repr__(self):
name = self.__class__.__name__

if self.__c:
return "{}(<binary data>)".format(name)

kwargs = {}
for i in self.param_names:
p = self.__dict__[i]
if isinstance(p, param_defs._ParamDefBase):
for k in p.__slots__:
v = getattr(p, k)
if isinstance(v, param_defs._EnumBase):
v = v.name
kwargs[k] = repr(v)
else:
kwargs[i] = repr(p)
if self.config:
if len(self.config.comp_node_arr) == 1:
kwargs["device"] = "'%s'" % self.config.comp_node
return "{}({})".format(
name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items())
)

+ 0
- 194
imperative/python/megengine/core/ops/_internal/misc_ops.py View File

@@ -1,194 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import ctypes

from ..._imperative_rt import OperatorNodeConfig as Config
from . import param_defs
from .helper import PodOpVisitor, make_param

__all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"]


class TensorShape:
MAX_NDIM = 7


class ConvolutionBackwardData(PodOpVisitor):
param_names = (
"param",
"execution_polity",
)
name = "ConvolutionBackwardDataV1"

def __init__(
self,
*,
param=None,
execution_polity=None,
name=None,
comp_node=None,
config=None,
dtype=None,
**kwargs
):
config = config or Config()
if name:
config.name = name
if comp_node:
config.comp_node = comp_node
if dtype:
config.dtype = dtype
self.config = config

self.param = make_param(param, param_defs.Convolution, kwargs)
self.execution_polity = make_param(
execution_polity, param_defs.ExecutionPolicy, kwargs
)
assert not kwargs, "extra kwargs: {}".format(kwargs)


class Dimshuffle(PodOpVisitor):
name = "Dimshuffle"
param_names = ("pattern",)

class Pattern(ctypes.Structure):
Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM
_fields_ = [
("length", ctypes.c_uint32),
("pattern", Pattern_Array),
("ndim", ctypes.c_uint32),
]

def serialize(self):
return bytes(ctypes.c_uint32(0)) + bytes(self)

def __init__(self, pattern, ndim=0):
assert isinstance(pattern, collections.abc.Iterable)
assert len(pattern) <= TensorShape.MAX_NDIM
pattern_array = Dimshuffle.Pattern.Pattern_Array()
for idx, v in enumerate(pattern):
pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v))
self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim)


class Reshape(PodOpVisitor):
name = "ReshapeV1"
param_names = ("unspec_axis",)

def __init__(self, unspec_axis=None):
if unspec_axis is None:
self.unspec_axis = param_defs.OptionalAxisV1()
else:
self.unspec_axis = param_defs.OptionalAxisV1(unspec_axis)


class AxisNum(ctypes.Structure):
_fields_ = [
("m_num", ctypes.c_int),
]


class AxisDesc(ctypes.Structure):
class Method(ctypes.c_int):
ADD_1 = 0
REMOVE = 1

_fields_ = [
("method", Method),
("axis", AxisNum),
]

@classmethod
def make_add(cls, axis):
return cls(cls.Method.ADD_1, AxisNum(axis))

@classmethod
def make_remove(cls, axis):
return cls(cls.Method.REMOVE, AxisNum(axis))


class AxisAddRemove(PodOpVisitor):
name = "AxisAddRemove"
param_names = ("param",)

AxisDesc = AxisDesc

class Param(ctypes.Structure):
MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2

_fields_ = [("nr_desc", ctypes.c_uint32), ("desc", AxisDesc * MAX_DESC_SIZE)]

def __init__(self, *args):
super().__init__()
self.nr_desc = len(args)
for i, a in enumerate(args):
self.desc[i] = a

def serialize(self):
return bytes(ctypes.c_uint32(0)) + bytes(self)

def __init__(self, param):
assert isinstance(param, self.Param)
self.param = param


del AxisDesc


class IndexingOpBase(PodOpVisitor):
param_names = ("index_desc",)

class IndexDescMaskDump(ctypes.Structure):
class Item(ctypes.Structure):
_fields_ = [
("axis", ctypes.c_int8),
("begin", ctypes.c_bool),
("end", ctypes.c_bool),
("step", ctypes.c_bool),
("idx", ctypes.c_bool),
]

Item_Array = Item * TensorShape.MAX_NDIM

_fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)]

def serialize(self):
return bytes(ctypes.c_uint32(0)) + bytes(self)

def __init__(self, items):
nr_item = len(items)
assert nr_item <= TensorShape.MAX_NDIM
item_array = IndexingOpBase.IndexDescMaskDump.Item_Array()
for idx, item in enumerate(items):
assert isinstance(item, (tuple, list)) and len(item) == 5
item_array[idx] = IndexingOpBase.IndexDescMaskDump.Item(*item)
self.index_desc = IndexingOpBase.IndexDescMaskDump(nr_item, item_array)


def _gen_indexing_defs(*names):
for name in names:
globals()[name] = type(name, (IndexingOpBase,), dict(name=name))
__all__.append(name)


_gen_indexing_defs(
"Subtensor",
"SetSubtensor",
"IncrSubtensor",
"IndexingMultiAxisVec",
"IndexingSetMultiAxisVec",
"IndexingIncrMultiAxisVec",
"MeshIndexing",
"IncrMeshIndexing",
"SetMeshIndexing",
"BatchedMeshIndexing",
"BatchedIncrMeshIndexing",
"BatchedSetMeshIndexing",
)

+ 1
- 14
imperative/python/megengine/core/ops/builtin/__init__.py View File

@@ -11,25 +11,12 @@ from typing import Union

from ..._imperative_rt import OpDef, ops
from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .._internal import all_ops
from .._internal.helper import PodOpVisitor

# register OpDef as a "virtual subclass" of OpBase, so any of registered
# apply(OpBase, ...) rules could work well on OpDef
OpBase.register(OpDef)

# forward to apply(OpDef, ...)
@apply.register()
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]):
return apply(op.to_c(), *args)


__all__ = ["OpDef", "PodOpVisitor"]

for k, v in all_ops.__dict__.items():
if isinstance(v, type) and issubclass(v, PodOpVisitor):
globals()[k] = v
__all__.append(k)
__all__ = ["OpDef"]

for k, v in ops.__dict__.items():
if isinstance(v, type) and issubclass(v, OpDef):


+ 2
- 5
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -90,7 +90,7 @@ def _reshape(x, shape):
if unspec_axis is None:
op = builtin.Reshape()
else:
op = builtin.Reshape(unspec_axis=unspec_axis)
op = builtin.Reshape(axis=unspec_axis)
(x,) = apply(op, x, shape)
return x

@@ -144,8 +144,6 @@ def _logical_binary_elwise(mode, rev=False):


def _remove_axis(inp: Tensor, axis) -> Tensor:
Param = builtin.AxisAddRemove.Param

def get_axes():
if axis is None:
return [i for i, s in enumerate(inp.shape) if s == 1]
@@ -159,8 +157,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
axis = [a - i for i, a in enumerate(axis)]

param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
op = builtin.AxisAddRemove(param=param)
op = builtin.RemoveAxis(axis=axis)
(result,) = apply(op, inp)
if len(axis) == inp.ndim:
setscalar(result)


+ 1
- 1
imperative/python/megengine/core/tensor/utils.py View File

@@ -134,7 +134,7 @@ def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_equal(x.dtype, dtype):
isscalar = x.__wrapped__._data._isscalar
(x,) = apply(builtin.TypeCvt(param=dtype), x)
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
x.__wrapped__._data._isscalar = isscalar
return x



+ 24
- 22
imperative/python/megengine/distributed/functional.py View File

@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
Tracer,
@@ -110,17 +109,20 @@ def collective_comm(inp, mode, group, device):
assert isinstance(group, Group)
if group is None:
return inp
op = CollectiveComm()
op.key = group.key
op.nr_devices = group.size
op.rank = group.rank
op.is_root = op.rank == 0
op.local_grad = False
op.addr, op.port = get_mm_server_addr()
op.mode = mode
op.dtype = inp.dtype
op.backend = get_backend()
op.comp_node = device
addr, port = get_mm_server_addr()
op = CollectiveComm(
key=group.key,
nr_devices=group.size,
rank=group.rank,
is_root=(group.rank == 0),
local_grad=False,
addr=addr,
port=port,
mode=mode,
dtype=inp.dtype,
backend=get_backend(),
comp_node=device,
)
return apply(op, inp)[0]


@@ -134,7 +136,7 @@ def reduce_sum(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.REDUCE_SUM
mode = CollectiveComm.Mode.REDUCE_SUM
return collective_comm(inp, mode, group, device)


@@ -148,7 +150,7 @@ def broadcast(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.BROADCAST
mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device)


@@ -162,7 +164,7 @@ def all_gather(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_GATHER
mode = CollectiveComm.Mode.ALL_GATHER
return collective_comm(inp, mode, group, device)


@@ -176,7 +178,7 @@ def reduce_scatter_sum(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.REDUCE_SCATTER_SUM
mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
return collective_comm(inp, mode, group, device)


@@ -190,7 +192,7 @@ def all_reduce_sum(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_REDUCE_SUM
mode = CollectiveComm.Mode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device)


@@ -204,7 +206,7 @@ def all_reduce_max(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_REDUCE_MAX
mode = CollectiveComm.Mode.ALL_REDUCE_MAX
return collective_comm(inp, mode, group, device)


@@ -218,7 +220,7 @@ def all_reduce_min(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_REDUCE_MIN
mode = CollectiveComm.Mode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device)


@@ -232,7 +234,7 @@ def gather(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.GATHER
mode = CollectiveComm.Mode.GATHER
return collective_comm(inp, mode, group, device)


@@ -246,7 +248,7 @@ def scatter(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.SCATTER
mode = CollectiveComm.Mode.SCATTER
return collective_comm(inp, mode, group, device)


@@ -260,7 +262,7 @@ def all_to_all(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_TO_ALL
mode = CollectiveComm.Mode.ALL_TO_ALL
return collective_comm(inp, mode, group, device)




+ 0
- 20
imperative/python/megengine/functional/elemwise.py View File

@@ -73,27 +73,7 @@ __all__ = [
]


class _ElemwiseMode(Elemwise.Mode):
@classmethod
def __normalize(cls, val):
if isinstance(val, str):
if not hasattr(cls, "__member_upper_dict__"):
cls.__member_upper_dict__ = {
k.upper(): v for k, v in cls.__members__.items()
}
val = cls.__member_upper_dict__.get(val.upper(), val)
return val

@classmethod
def convert(cls, val):
val = cls.__normalize(val)
if isinstance(val, cls):
return val
return cls(val)


def _elwise(*args, mode):
mode = _ElemwiseMode.convert(mode)
op = builtin.Elemwise(mode)
tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)


+ 7
- 9
imperative/python/megengine/functional/math.py View File

@@ -13,7 +13,6 @@ import numbers
from typing import Optional, Sequence, Tuple, Union

from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
@@ -601,9 +600,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor:
"""
assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending:
order = P.Argsort.Order.DESCENDING
order = "DESCENDING"
else:
order = P.Argsort.Order.ASCENDING
order = "ASCENDING"

op = builtin.Argsort(order=order)
if len(inp.shape) == 1:
@@ -643,9 +642,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
"""
assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending:
order = P.Argsort.Order.DESCENDING
order = "DESCENDING"
else:
order = P.Argsort.Order.ASCENDING
order = "ASCENDING"

op = builtin.Argsort(order=order)
if len(inp.shape) == 1:
@@ -695,13 +694,12 @@ def topk(
if descending:
inp = -inp

Mode = P.TopK.Mode
if kth_only:
mode = Mode.KTH_ONLY
mode = "KTH_ONLY"
elif no_sort:
mode = Mode.VALUE_IDX_NOSORT
mode = "VALUE_IDX_NOSORT"
else:
mode = Mode.VALUE_IDX_SORTED
mode = "VALUE_IDX_SORTED"
op = builtin.TopK(mode=mode)

if not isinstance(k, (TensorBase, TensorWrapperBase)):


+ 10
- 19
imperative/python/megengine/functional/nn.py View File

@@ -12,7 +12,6 @@ from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
@@ -121,11 +120,11 @@ def conv2d(
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`
:type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`
:class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
@@ -139,8 +138,8 @@ def conv2d(
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)

Sparse = P.Convolution.Sparse
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
Sparse = builtin.Convolution.Sparse
sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.Convolution(
stride_h=stride_h,
stride_w=stride_w,
@@ -187,11 +186,11 @@ def conv_transpose2d(
``in_channels`` and ``out_channels`` must be divisible by groups,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. Default: 1
:type conv_mode: string or :class:`P.Convolution.Mode`
:type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`
:class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
@@ -240,8 +239,6 @@ def local_conv2d(
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)

Sparse = P.Convolution.Sparse

op = builtin.GroupLocal(
stride_h=stride_h,
stride_w=stride_w,
@@ -251,7 +248,7 @@ def local_conv2d(
dilate_w=dilate_w,
mode=conv_mode,
compute_mode="DEFAULT",
sparse=Sparse.DENSE,
sparse="DENSE",
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
@@ -696,19 +693,14 @@ def batch_norm(

if not training:
op = builtin.BatchNorm(
BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0
fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11"
)
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
return ret

else:
op = builtin.BatchNorm(
BatchNorm.ParamDim.DIM_1C11,
BatchNorm.FwdMode.TRAINING,
eps,
1.0 - momentum,
1.0,
0.0,
avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
)
if has_mean or has_var:
running_mean = make_full_if_none(running_mean, 0)
@@ -1638,8 +1630,7 @@ def conv1d(
pad_h = padding
dilate_h = dilation

Sparse = P.Convolution.Sparse
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.Convolution(
stride_h=stride_h,
stride_w=1,


+ 6
- 6
imperative/python/megengine/functional/quantized.py View File

@@ -41,12 +41,12 @@ def conv_bias_activation(
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`.
:type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'
:param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`.
:class:`Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
@@ -56,7 +56,7 @@ def conv_bias_activation(
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.ConvBiasForward(
op = builtin.ConvBias(
stride_h=sh,
stride_w=sw,
pad_h=ph,
@@ -101,12 +101,12 @@ def batch_conv_bias_activation(
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`.
:type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'
:param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`.
:class:`Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
@@ -116,7 +116,7 @@ def batch_conv_bias_activation(
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.BatchConvBiasForward(
op = builtin.BatchConvBias(
stride_h=sh,
stride_w=sw,
pad_h=ph,


+ 2
- 9
imperative/python/megengine/functional/tensor.py View File

@@ -16,7 +16,6 @@ import numpy as np
from ..core._imperative_rt import CompNode
from ..core._wrap import device as as_device
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
@@ -722,7 +721,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
[1 0]]

"""
return inp.transpose(pattern)
return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern))


def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
@@ -756,10 +755,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
return inp.reshape(target_shape)


AxisAddRemove = builtin.AxisAddRemove
AxisDesc = AxisAddRemove.AxisDesc


def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
r"""
Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.
@@ -826,7 +821,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
(1, 2)

"""
Param = builtin.AxisAddRemove.Param

def get_axes():
try:
@@ -839,8 +833,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis)

param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis))
op = builtin.AxisAddRemove(param=param)
op = builtin.AddAxis(axis=axis)
(result,) = apply(op, inp)
return result



+ 3
- 9
imperative/python/megengine/jit/tracing.py View File

@@ -21,9 +21,10 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import (
CollectiveComm,
OprAttr,
GaussianRNG,
RemoteRecv,
RemoteSend,
UniformRNG,
VirtualDep,
)
from ..core._trace_option import set_symbolic_shape
@@ -182,14 +183,7 @@ class trace:
record = self._seq[self._pc]
op_, ihandles, ohandles = record
if op != op_:
# FIXME: will be removed once better rng implementation is done
if isinstance(op, OprAttr) and (
op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type
):
if op.param[8:] != op_.param[8:]:
raise TraceMismatchError("op different from last time")
else:
raise TraceMismatchError("op different from last time")
raise TraceMismatchError("op different from last time")
if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time")



+ 6
- 17
imperative/python/megengine/module/conv.py View File

@@ -10,7 +10,6 @@ from typing import Tuple, Union

import numpy as np

from ..core.ops._internal import param_defs as P
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu
from ..functional.types import _pair, _pair_nonzero
from ..tensor import Parameter
@@ -156,8 +155,6 @@ class Conv1d(_ConvNd):
(2, 1, 2)

"""
_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode

def __init__(
self,
@@ -176,8 +173,8 @@ class Conv1d(_ConvNd):
stride = stride
padding = padding
dilation = dilation
self.conv_mode = self._conv_mode_type.convert(conv_mode)
self.compute_mode = self._compute_mode_type.convert(compute_mode)
self.conv_mode = conv_mode
self.compute_mode = compute_mode
super().__init__(
in_channels,
out_channels,
@@ -302,9 +299,6 @@ class Conv2d(_ConvNd):

"""

_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode

def __init__(
self,
in_channels: int,
@@ -322,8 +316,8 @@ class Conv2d(_ConvNd):
stride = _pair_nonzero(stride)
padding = _pair(padding)
dilation = _pair_nonzero(dilation)
self.conv_mode = self._conv_mode_type.convert(conv_mode)
self.compute_mode = self._compute_mode_type.convert(compute_mode)
self.conv_mode = conv_mode
self.compute_mode = compute_mode
super().__init__(
in_channels,
out_channels,
@@ -414,9 +408,6 @@ class ConvTranspose2d(_ConvNd):
effective when input and output are of float16 dtype.
"""

_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode

def __init__(
self,
in_channels: int,
@@ -434,8 +425,8 @@ class ConvTranspose2d(_ConvNd):
stride = _pair_nonzero(stride)
padding = _pair(padding)
dilation = _pair_nonzero(dilation)
self.conv_mode = self._conv_mode_type.convert(conv_mode)
self.compute_mode = self._compute_mode_type.convert(compute_mode)
self.conv_mode = conv_mode
self.compute_mode = compute_mode
super().__init__(
in_channels,
out_channels,
@@ -509,8 +500,6 @@ class LocalConv2d(Conv2d):
in_channels // groups, *kernel_size, out_channels // groups)`.
"""

_conv_mode_type = P.Convolution.Mode

def __init__(
self,
in_channels: int,


+ 0
- 1
imperative/python/megengine/module/elemwise.py View File

@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core.ops._internal import param_defs as P
from ..functional.elemwise import _elwise
from ..tensor import Tensor
from .module import Module


+ 2
- 2
imperative/python/megengine/module/qat/conv.py View File

@@ -41,8 +41,8 @@ class Conv2d(Float.Conv2d, QATModule):
float_module.dilation,
float_module.groups,
float_module.bias is not None,
float_module.conv_mode.name,
float_module.compute_mode.name,
float_module.conv_mode,
float_module.compute_mode,
)
qat_module.weight = float_module.weight
qat_module.bias = float_module.bias


+ 1
- 4
imperative/python/megengine/module/quantized/elemwise.py View File

@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core.ops._internal import param_defs as P
from ...functional.elemwise import _elemwise_multi_type
from ...tensor import Tensor
from ..qat import elemwise as QAT
@@ -15,11 +14,9 @@ from .module import QuantizedModule
class Elemwise(QuantizedModule):
r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""

_elemwise_multi_type_mode = P.ElemwiseMultiType.Mode

def __init__(self, method, dtype=None):
super().__init__()
self.method = self._elemwise_multi_type_mode.convert("Q" + method)
self.method = "Q" + method
self.output_dtype = dtype

def forward(self, *inps):


+ 2
- 2
imperative/python/megengine/utils/profiler.py View File

@@ -15,7 +15,7 @@ from typing import Iterable, List, Optional
from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
from ..core._imperative_rt import ProfilerImpl as _Profiler
from ..core._imperative_rt.imperative import sync
from ..core._imperative_rt.ops import CollectiveCommMode
from ..core._imperative_rt.ops import CollectiveComm


def _make_dict(**kwargs):
@@ -194,7 +194,7 @@ class Profiler:
_type_map = {
OperatorNodeConfig: lambda x: _print_opnode_config(x),
bytes: lambda x: base64.encodebytes(x).decode("ascii"),
CollectiveCommMode: lambda x: str(x),
CollectiveComm.Mode: lambda x: str(x),
}

_dumper_map = {


+ 1
- 3
imperative/python/src/graph_rt.cpp View File

@@ -421,9 +421,7 @@ void init_graph_rt(py::module m) {

common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
auto opr = OpDef::apply_on_var_node(def, vinputs);
auto outputs = opr->usable_output();
return to_tuple(outputs);
return to_tuple(OpDef::apply_on_var_node(def, vinputs));
},
py::arg(), py::arg(), py::arg("graph") = py::none());



+ 0
- 3
imperative/python/src/imperative_rt.cpp View File

@@ -109,9 +109,6 @@ void init_imperative_rt(py::module m) {

py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef")
.def("ctype", [](const OpDef& opdef) {
if (auto attr = opdef.try_cast_final<OprAttr>()) {
return attr->type.c_str();
}
return opdef.dyn_typeinfo()->name;
})
.def("__eq__", [](const OpDef& lhs, const OpDef& rhs) {


+ 13
- 179
imperative/python/src/ops.cpp View File

@@ -14,41 +14,29 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/tensor_manip.h"
#include "megbrain/imperative/ops/collective_comm.h"
#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain/imperative/ops/cond_take.h"
#include "megbrain/imperative/ops/nms.h"
#include "megbrain/imperative/ops/elemwise.h"
#include "megbrain/imperative/ops/batch_norm.h"
#include "megbrain/imperative/ops/broadcast.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"

namespace py = pybind11;

namespace {
auto normalize_enum(const std::string& in) {
std::string ret;
for (auto&& c : in) {
ret += toupper(c);
}
return ret;
}
} // anonymous namespace

void init_ops(py::module m) {
using namespace mgb::imperative;

py::class_<OprAttr, std::shared_ptr<OprAttr>, OpDef>(m, "OprAttr")
.def(py::init<>())
.def_readwrite("type", &OprAttr::type)
.def_readwrite("param", &OprAttr::param)
.def_readwrite("config", &OprAttr::config)
.def_property("param",
[](const OprAttr& attr) -> py::bytes {
return std::string(attr.param.begin(), attr.param.end());
},
[] (OprAttr& attr, py::bytes data) {
auto s = py::cast<std::string>(data);
attr.param.clear();
attr.param.insert(attr.param.end(), s.begin(), s.end());
});

py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) {
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs));
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
};
auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor());
@@ -56,162 +44,8 @@ void init_ops(py::module m) {
return self.graph().interpret<py::object>(f, c, inputs);
});

py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape")
.def(py::init());

#define V(m) .value(#m, CollectiveComm::Mode::m)
py::enum_<CollectiveComm::Mode>(m, "CollectiveCommMode")
V(REDUCE_SUM)
V(BROADCAST)
V(ALL_GATHER)
V(REDUCE_SCATTER_SUM)
V(ALL_REDUCE_SUM)
V(ALL_REDUCE_MAX)
V(ALL_REDUCE_MIN)
V(ALL_REDUCE_PROD)
V(GATHER)
V(SCATTER)
V(ALL_TO_ALL);
#undef V

py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef>(m, "CollectiveComm")
.def(py::init<>())
.def_readwrite("key", &CollectiveComm::key)
.def_readwrite("nr_devices", &CollectiveComm::nr_devices)
.def_readwrite("rank", &CollectiveComm::rank)
.def_readwrite("is_root", &CollectiveComm::is_root)
.def_readwrite("local_grad", &CollectiveComm::local_grad)
.def_readwrite("addr", &CollectiveComm::addr)
.def_readwrite("port", &CollectiveComm::port)
.def_readwrite("mode", &CollectiveComm::mode)
.def_readwrite("dtype", &CollectiveComm::dtype)
.def_readwrite("backend", &CollectiveComm::backend)
.def_readwrite("comp_node", &CollectiveComm::comp_node);

py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef>(m, "RemoteSend")
.def(py::init<>())
.def_readwrite("key", &RemoteSend::key)
.def_readwrite("addr", &RemoteSend::addr)
.def_readwrite("port", &RemoteSend::port)
.def_readwrite("rank_to", &RemoteSend::rank_to);

py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef>(m, "RemoteRecv")
.def(py::init<>())
.def_readwrite("key", &RemoteRecv::key)
.def_readwrite("addr", &RemoteRecv::addr)
.def_readwrite("port", &RemoteRecv::port)
.def_readwrite("rank_from", &RemoteRecv::rank_from)
.def_readwrite("shape", &RemoteRecv::shape)
.def_readwrite("cn", &RemoteRecv::cn)
.def_readwrite("dtype", &RemoteRecv::dtype);

py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef>(m, "ParamPackSplit")
.def(py::init<>())
.def_readwrite("offsets", &ParamPackSplit::offsets)
.def_readwrite("shapes", &ParamPackSplit::shapes);

py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef>(m, "ParamPackConcat")
.def(py::init<>())
.def_readwrite("offsets", &ParamPackConcat::offsets);

py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
.def(py::init<>());

py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake")
.def(py::init<>());

py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef>(m, "NMSKeep")
.def(py::init<float, uint32_t>())
.def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
.def_readwrite("max_output", &NMSKeep::max_output);

py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise");
elemwise.def(py::init<Elemwise::Mode>())
.def_readwrite("mode", &Elemwise::mode);

#define V(m) .value(#m, Elemwise::Mode::m)
py::enum_<Elemwise::Mode>(elemwise, "Mode")
V(RELU)
V(ABS)
V(ACOS)
V(ASIN)
V(CEIL)
V(COS)
V(EXP)
V(EXPM1)
V(FLOOR)
V(LOG)
V(LOG1P)
V(NEGATE)
V(SIGMOID)
V(SIN)
V(TANH)
V(ABS_GRAD)
V(ADD)
V(FLOOR_DIV)
V(MAX)
V(MIN)
V(MOD)
V(MUL)
V(POW)
V(SIGMOID_GRAD)
V(SUB)
V(SWITCH_GT0)
V(TANH_GRAD)
V(TRUE_DIV)
V(LOG_SUM_EXP)
V(LT)
V(LEQ)
V(EQ)
V(SHL)
V(SHR)
V(COND_LEQ_MOV)
V(FUSE_MUL_ADD3)
V(FUSE_MUL_ADD4)
V(FUSE_ADD_RELU)
V(FUSE_ADD_SIGMOID)
V(FUSE_ADD_TANH)
V(FAST_TANH)
V(FAST_TANH_GRAD)
V(ROUND)
V(RMULH)
V(ATAN2)
V(ERF)
V(ERFINV)
V(ERFC)
V(ERFCINV)
V(H_SWISH)
V(H_SWISH_GRAD)
V(FUSE_ADD_H_SWISH)
V(NOT)
V(AND)
V(OR)
V(XOR);
#undef V

py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm");
batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&, double, double, float, float>())
.def_readwrite("param_dim", &BatchNorm::param_dim)
.def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
.def_readwrite("epsilon", &BatchNorm::epsilon)
.def_readwrite("avg_factor", &BatchNorm::avg_factor)
.def_readwrite("scale", &BatchNorm::scale)
.def_readwrite("bias", &BatchNorm::bias);

#define V(m) .value(#m, BatchNorm::Param::ParamDim::m)
py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim")
V(DIM_11HW)
V(DIM_1CHW)
V(DIM_1C11);
#undef V

#define V(m) .value(#m, BatchNorm::Param::FwdMode::m)
py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode")
V(TRAINING)
V(INFERENCE);
#undef V

py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast")
.def(py::init<>());

#include "opdef.py.inl"
}

+ 2
- 2
imperative/python/test/unit/core/test_dtype_quant.py View File

@@ -113,7 +113,7 @@ def test_quint8_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1

def typecvt(x, dt=None):
(y,) = apply(ops.TypeCvt(param=dt), x)
(y,) = apply(ops.TypeCvt(dtype=dt), x)
return y

# convert to quint8
@@ -194,7 +194,7 @@ def test_quint4_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1

def typecvt(x, dt=None):
(y,) = apply(ops.TypeCvt(param=dt), x)
(y,) = apply(ops.TypeCvt(dtype=dt), x)
return y

# convert to quint4


+ 17
- 18
imperative/python/test/unit/core/test_indexing_op.py View File

@@ -11,10 +11,9 @@ import collections
import numpy as np
import pytest

import megengine.core.ops.builtin
import megengine.core.tensor.raw_tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops._internal import all_ops
from megengine.core.ops import builtin
from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor
@@ -105,7 +104,7 @@ def canonize_inputs(inputs, *, config):
need_cvt = False
for i in old_inputs:
if isinstance(i, RawTensor):
get_comp_node = lambda cn=i.device.to_c(): cn
get_comp_node = lambda cn=i.device: cn
else:
need_cvt = True
inputs.append(i)
@@ -193,91 +192,91 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):


def transpose(*args, **kwargs):
op = all_ops.Dimshuffle(**kwargs).to_c()
op = builtin.Dimshuffle(**kwargs)
return invoke_op(op, args)


def broadcast(input, tshape):
op = all_ops.Broadcast().to_c()
op = builtin.Broadcast()
return invoke_op(op, (input, tshape), canonize_reshape)


def subtensor(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.Subtensor(items).to_c()
op = builtin.Subtensor(items)
return invoke_op(op, (input, *tensors))


def set_subtensor(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.SetSubtensor(items).to_c()
op = builtin.SetSubtensor(items)
return invoke_op(op, (input, value, *tensors))


def incr_subtensor(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IncrSubtensor(items).to_c()
op = builtin.IncrSubtensor(items)
return invoke_op(op, (input, value, *tensors))


def advance_indexing(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IndexingMultiAxisVec(items).to_c()
op = builtin.IndexingMultiAxisVec(items)
return invoke_op(op, (input, *tensors))


def set_advance_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IndexingSetMultiAxisVec(items).to_c()
op = builtin.IndexingSetMultiAxisVec(items)
return invoke_op(op, (input, value, *tensors))


def incr_advance_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IndexingIncrMultiAxisVec(items).to_c()
op = builtin.IndexingIncrMultiAxisVec(items)
return invoke_op(op, (input, value, *tensors))


def mesh_indexing(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.MeshIndexing(items).to_c()
op = builtin.MeshIndexing(items)
return invoke_op(op, (input, *tensors))


def set_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.SetMeshIndexing(items).to_c()
op = builtin.SetMeshIndexing(items)
return invoke_op(op, (input, value, *tensors))


def incr_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IncrMeshIndexing(items).to_c()
op = builtin.IncrMeshIndexing(items)
return invoke_op(op, (input, value, *tensors))


def batched_mesh_indexing(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.BatchedMeshIndexing(items).to_c()
op = builtin.BatchedMeshIndexing(items)
return invoke_op(op, (input, *tensors))


def batched_set_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.BatchedSetMeshIndexing(items).to_c()
op = builtin.BatchedSetMeshIndexing(items)
return invoke_op(op, (input, value, *tensors))


def batched_incr_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.BatchedIncrMeshIndexing(items).to_c()
op = builtin.BatchedIncrMeshIndexing(items)
return invoke_op(op, (input, value, *tensors))


def test_transpose():
x = np.arange(10).reshape(2, 5).astype("int32")
xx = as_raw_tensor(x)
(yy,) = transpose(xx, pattern="1x0")
(yy,) = transpose(xx, pattern=[1, -1, 0])
np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy())




+ 0
- 320
imperative/python/tools/gen_ops.py View File

@@ -1,320 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from io import StringIO
import re
import argparse
import subprocess
import os
import textwrap
import inspect

def camel2underscore(
name, *,
first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
all_cap_re = re.compile('([a-z])([A-Z]+)')):
if name.isupper():
return name.lower()
s1 = first_cap_re.sub(r'\1_\2', name)
return all_cap_re.sub(r'\1_\2', s1).lower()


def caller_lineno(level=1):
f = inspect.stack()[level+1]
return '%s:%d' % (f.filename, f.lineno)


class Doc:
"""wrap an identifier and doc"""
_id = None

def __init__(self, id_, doc, typestr=None, default=None):
self._id = id_
self.doc = doc
self.typestr = typestr
self.default = default

def __str__(self):
return self._id


class Context:
fout = None

def __init__(self):
self.fout = StringIO()
self.indent = 0
self.skipped = []
self.generated_signature = set()
self.generated_opr = dict()

def write(self, text, *fmt, indent=0):
text = textwrap.dedent(text)
text = textwrap.indent(text, ' '*4*(self.indent + indent))
text = text % fmt
if not text.endswith('\n'):
text += '\n'
self.fout.write(text)

def _gen_signature(self, params, *, have_config=True,
has_out_dtype=False):
sig = ['self', '*']

for i, _ in params:
sig.append('{}=None'.format(i))

if have_config:
sig.extend(['name=None', 'comp_node=None', 'config=None'])
if has_out_dtype:
sig.append('dtype=None')

if params:
sig.append('**kwargs')

if sig[-1] == '*':
sig.pop()
return ', '.join(sig)

def _write_canonize_inputs(self, inputs, convert_inputs,
convert_inputs_args=None,
has_out_dtype=False):
self._write_gen_config(has_out_dtype)
inputs = list(map(str, inputs))
if convert_inputs_args is None:
if inputs[0][0] == '*':
arg = inputs[0][1:]
else:
arg = '[{}]'.format(', '.join(inputs))
else:
arg = convert_inputs_args
self.write('inputs = helper.%s(%s, config=config)',
convert_inputs, arg)

def _write_gen_config(self, has_out_dtype=False):
self.write('''\
config = config or Config()
if name:
config.name = name
if comp_node:
config.comp_node = comp_node
''')
if has_out_dtype:
self.write('''\
if dtype:
config.dtype = dtype
''')
self.write('self.config = config')

def _write_make_params(self, params):
for pname, ptype in params:
self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)',
pname, pname, ptype)
self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)')

def _write_doc(self, inputs, params, desc):
self.write('"""')
if isinstance(desc, Doc):
assert desc._id is None
self.write(desc.doc)
elif desc:
for i in textwrap.wrap(desc, 75):
self.write(i)

self.write('')
for i in inputs:
name = str(i)
typestr = ':class:`.Tensor`'
if name[0] == '*':
name = name[1:]
typestr = 'list of ' + typestr
if isinstance(i, Doc):
self.write(':param %s: %s', name, i.doc)
if i.typestr is not None:
typestr = i.typestr
if typestr:
if not isinstance(i, Doc):
self.write(':param %s: ', name)
self.write(':type %s: %s', name, typestr)

for pname, ptype in params:
self.write(':param %s: ', pname)
self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`',
pname, ptype)

self.write(':param comp_node: see doc for *config*')
self.write(':param name: see doc for *config*')
self.write(
':param config: give a :class:`.OperatorNodeConfig` object to set '
'operator name and comp node. This can also be achieved by passing '
'*comp_node* and *name* separately.')

self.write('"""')

def _write_return(self, name, outputs):
self.write('opdef = helper.PodOpVisitor("%s", config, params)', name)
self.write('outputs = helper.create_op(opdef, inputs)')
if outputs:
self.write('outputs = [outputs[i] for i in %s]',
list(map(int, outputs)))
self.write('return helper.convert_outputs(outputs)')

def decl_opr(self, name, *, inputs, params, desc=None, pyname=None,
canonize_input_vars=None,
canonize_input_vars_args=None, body=None,
outputs=None, version=0, has_out_dtype=False):
"""
:param inputs: name of variable inputs; a name starting with `*' means
a list of vars
:type inputs: list of str
:param params: (param name, param type) pairs; it can be a single
string representing the param type, and param name defaults to
'param'
:type params: list of pair of str, or str
:param pyname: python function name
:param body: extra statements to be placed before calling _create_opr
:param outputs: the indices of output vars to be selected from raw opr
result
"""

class OprItem:
def __init__(self, inputs, desc, params, version, has_out_dtype):
self.inputs = inputs
self.desc = desc
self.params = params
self.version = version
self.has_out_dtype = has_out_dtype

if body:
self.skipped.append(name)
return

signature = (name, params if isinstance(params, str) else frozenset(params), has_out_dtype, version)
if signature in self.generated_signature:
self.skipped.append(name)
return
else:
self.generated_signature.add(signature)

body = body or []
if isinstance(params, str):
params = [('param', params)]
assert params

if name in self.generated_opr:
org_opr = self.generated_opr[name]
if version > org_opr.version:
def compare_doc(a, b):
if isinstance(a, str):
return a == b
else:
assert isinstance(a, Doc)
return a.doc == b.doc

assert compare_doc(desc, org_opr.desc)
assert len(inputs) == len(org_opr.inputs)
for i, j in zip(inputs, org_opr.inputs):
assert compare_doc(i, j)

self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
else:
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)

def write_generated_oprs(self):

for opr, opr_item in self.generated_opr.items():

name = opr
params = opr_item.params
version = opr_item.version
has_out_dtype = opr_item.has_out_dtype

self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1

param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n')

self.write('def __init__(%s):',
self._gen_signature(params,
has_out_dtype=has_out_dtype))
self.indent += 1

self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n')

self._write_make_params(params)

self.write('\n')
self.indent -= 2


def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
desc=None, local_defs=[], have_config=True, params=None, has_out_dtype=False):
self.skipped.append(name)

def get_str(self):
return self.fout.getvalue()

def all_list(self):
buf = StringIO()
print(
'[',
*(' "%s",' % i for i in self.generated_opr),
']',
sep='\n',
file=buf
)
return buf.getvalue()


def main():
parser = argparse.ArgumentParser(
description='generate operator function def code from decl file')
parser.add_argument('inputs', nargs='+')
parser.add_argument('--output', '-o')
args = parser.parse_args()

gen = Context()
exec_globals = {
'decl_opr': gen.decl_opr,
'decl_raw_opr': gen.decl_raw_opr,
'Doc': Doc,
'camel2underscore': camel2underscore,
}
for i in args.inputs:
print('generate ops from {}'.format(i))
with open(i) as fin:
exec(compile(fin.read(), i, 'exec'), exec_globals)

gen.write_generated_oprs()
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], universal_newlines=True,
cwd=os.path.dirname(os.path.realpath(__file__))).strip()
except:
git_commit = 'NOT_A_GIT_REPO'

def relpath(*args):
d = os.path.dirname(__file__)
return os.path.join(d, *args)

with open(relpath('ops.tpl.py')) as fin:
with open(args.output, 'w') as fout:
fout.write(fin.read()
.replace('{%all%}', gen.all_list())
.replace('{%body%}', gen.get_str())
.replace('{%git_commit%}', git_commit))

print('Skipped:')
print(*gen.skipped, sep='\n')

if __name__ == '__main__':
main()

+ 0
- 40
imperative/python/tools/ops.tpl.py View File

@@ -1,40 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""This python module contains functions to apply the operators defined by
megbrain.

.. note::
Most of the functions are automatically generated, and their signature have
the form contain a ``param`` argument (or more than one arguments such as
:func:`convolution` that has ``param`` and ``execution_polity``) and also
accept keyword arguments. In such case, it can be called by either
providing a param object of appropriate type, or by passing the arguments
needed by the constructor of param object to the keyword arguments.
Furthermore, for a param that needs an enumeration member, the enum name
can be used to refer to the enum object.

For example, the following statements are equivalent::

elemwise([a, b], mode='max')
elemwise([a, b], mode=opr_param_defs.Elemwise.Mode.MAX)
elemwise([a, b], param=opr_param_defs.Elemwise('max'))
"""

__git_commit__ = "{%git_commit%}"

import collections

from . import helper
from .helper import PodOpVisitor
from . import param_defs
from ..._imperative_rt import OperatorNodeConfig as Config

__all__ = {%all%}

{%body%}

+ 9
- 1
imperative/src/impl/op_def.cpp View File

@@ -36,7 +36,7 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
return def.trait()->apply_on_physical_tensor(def, inputs);
}

cg::OperatorNodeBase* OpDef::apply_on_var_node(
VarNodeArray OpDef::apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
return def.trait()->apply_on_var_node(def, inputs);
@@ -56,6 +56,14 @@ BackwardGraphResult OpDef::make_backward_graph(
return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
}

size_t OpDef::hash() const {
return trait()->hash(*this);
}

bool OpDef::is_same_st(const Hashable& rhs) const {
return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

const OpTrait* OpDef::trait() const {
if (!m_trait) {
m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());


+ 1
- 1
imperative/src/impl/op_trait.cpp View File

@@ -23,7 +23,7 @@ namespace detail {

struct StaticData {
std::list<OpTrait> registries;
std::unordered_map<const char*, OpTrait*> name2reg;
std::unordered_map<std::string, OpTrait*> name2reg;
std::unordered_map<Typeinfo*, OpTrait*> type2reg;
};



+ 42
- 1
imperative/src/impl/op_trait.h View File

@@ -30,6 +30,32 @@ struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
return this->Base::operator ()(args...);
}
};
template<typename T>
struct ToVarNodeArray: std::false_type {};
template<>
struct ToVarNodeArray<SymbolVar>: std::true_type {
VarNodeArray operator()(const SymbolVar& inp) {
return {inp.node()};
}
};
template<>
struct ToVarNodeArray<SymbolVarArray>: std::true_type {
VarNodeArray operator()(const SymbolVarArray& inputs) {
return cg::to_var_node_array(inputs);
}
};
template<size_t N>
struct ToVarNodeArray<std::array<SymbolVar, N>>: std::true_type {
VarNodeArray operator()(const std::array<SymbolVar, N>& inp) {
return cg::to_var_node_array({inp.begin(), inp.end()});
}
};
template<>
struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
VarNodeArray operator()(const cg::OperatorNodeBase* opr) {
return opr->usable_output();
}
};
} // detail

using OpDefMaker = detail::OpMeth<
@@ -42,6 +68,8 @@ using InferOutputAttrsFallible = detail::OpMeth<
decltype(OpDef::infer_output_attrs_fallible)>;
using GradMaker = detail::OpMeth<
decltype(OpDef::make_backward_graph)>;
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;

struct OpTrait {
const char* name;
@@ -50,6 +78,8 @@ struct OpTrait {
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
GradMaker make_backward_graph;
HashFunc hash;
IsSame is_same_st;
OpTrait(const char* name);
static OpTrait* find_by_name(const char* name);
static OpTrait* find_by_typeinfo(Typeinfo* type);
@@ -61,7 +91,9 @@ struct OpTrait {
cb(apply_on_physical_tensor) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph)
cb(make_backward_graph) \
cb(hash) \
cb(is_same_st)

struct OpTraitRegistry {
OpTrait* trait;
@@ -97,6 +129,15 @@ struct OpTraitRegistry {
void do_insert(Typeinfo* type);

static OpTraitRegistry do_insert(const char* name);

template<typename T,
typename To = detail::ToVarNodeArray<T>,
typename = std::enable_if_t<To::value>>
OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) {
return apply_on_var_node([=](const OpDef& opdef, const VarNodeArray& inputs) {
return To()(f(opdef, inputs));
});
}
};

} // namespace imperative


+ 46
- 0
imperative/src/impl/ops/autogen.cpp View File

@@ -0,0 +1,46 @@
/**
* \file imperative/src/impl/ops/autogen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

using namespace megdnn;

// FIXME: remove this when mgb::hash support tuple_hash
namespace mgb {
namespace {
template<typename T, size_t ...Ns>
auto tail(T t, std::index_sequence<Ns...>) {
return std::make_tuple(std::get<Ns+1>(t)...);
}
} // anonymous namespace
template<typename T, typename ...Args>
class HashTrait<std::tuple<T, Args...>> {
constexpr static size_t length = sizeof...(Args);
public:
static size_t eval(const std::tuple<T, Args...> &t) {
const T& val = std::get<0>(t);
if constexpr (!length) {
return mgb::hash(val);
} else {
return mgb::hash_pair_combine(mgb::hash(val),
mgb::hash(tail(t, std::make_index_sequence<length - 1>{})));
}
}
};
} // namespace mgb

namespace mgb::imperative {

#include "./opdef.cpp.inl"

} // namespace mgb::imperative

+ 6
- 11
imperative/src/impl/ops/batch_norm.cpp View File

@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/batch_norm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "../op_trait.h"

namespace mgb {
@@ -19,9 +20,7 @@ namespace {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::BatchNorm>();
auto&& param = node->param();
return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon,
param.avg_factor, param.scale, param.bias);
return BatchNorm::make(node->param());
}

cg::OperatorNodeBase* apply_on_var_node(
@@ -33,13 +32,11 @@ cg::OperatorNodeBase* apply_on_var_node(
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
if (nr_inp == 3) {
return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2],
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
inputs[0], inputs[1], inputs[2], bn_opr.param())[0]
.node()->owner_opr();
} else {
return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0]
.node()->owner_opr();
}
}
@@ -52,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(nr_inp == 3 ||nr_inp == 5,
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
// need running mean/variance
bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING;
bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING;
size_t nr_out = need_stat? 5 : 3;
SmallVector<LogicalTensorDesc> out_shapes(nr_out);
auto&& i0 = inputs[0];
@@ -76,8 +73,6 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
.fallback();
} // anonymous namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm);

} // namespace imperative
} // namespace mgb



+ 3
- 3
imperative/src/impl/ops/broadcast.cpp View File

@@ -9,7 +9,9 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/broadcast.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"

#include "../op_trait.h"

namespace mgb {
@@ -87,8 +89,6 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.fallback();
} // anonymous namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);

} // namespace imperative
} // namespace mgb



+ 3
- 32
imperative/src/impl/ops/collective_comm.cpp View File

@@ -18,7 +18,7 @@
#include "megbrain/utils/hash.h"
#endif // MGB_ENABLE_OPR_MM

#include "megbrain/imperative/ops/collective_comm.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {
@@ -61,8 +61,8 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) {
auto [addr, port] = split_address(group_client->get_addr());
auto comp_node = node->config().get_single_comp_node().to_string_logical();
return std::make_shared<CollectiveComm>(
comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(),
comm.local_grad(), addr, std::stoi(port), comm.param().mode,
comm.param().mode, comm.key(), comm.nr_devices(), comm.rank(),
comm.is_root(), comm.local_grad(), addr, std::stoi(port),
comm.dtype(), comm.backend(), comp_node);
}

@@ -73,35 +73,6 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
} // anonymous namespace
#endif // MGB_ENABLE_OPR_MM

bool CollectiveComm::is_same_st(const Hashable& another) const{
auto* comm_opr = another.try_cast_final<CollectiveComm>();
if(!comm_opr){
return false;
}
return as_tuple() == comm_opr->as_tuple();
}

size_t CollectiveComm::hash() const{
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(nr_devices);
append(rank);
append(is_root);
append(local_grad);
append(addr);
append(port);
append(mode);
append(backend);
append(comp_node);
return xxhash.digest();
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);

} // namespace imperative
} // namespace mgb



+ 1
- 4
imperative/src/impl/ops/cond_take.cpp View File

@@ -9,8 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/cond_take.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/misc.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
@@ -19,8 +18,6 @@ using namespace megdnn;

namespace mgb::imperative {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);

namespace {

class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {


+ 4
- 4
imperative/src/impl/ops/elemwise.cpp View File

@@ -9,7 +9,9 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/elemwise.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"

#include "../op_trait.h"

namespace mgb {
@@ -33,7 +35,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
auto trait = Elemwise::ModeTrait::from_mode(op_def.mode);
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
mgb_assert(inputs.size() == trait.arity,
"%s expects %u inputs; got %zu actually", trait.name,
trait.arity, inputs.size());
@@ -70,8 +72,6 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.fallback();
} // anonymous namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise);

} // namespace imperative
} // namespace mgb



+ 1
- 41
imperative/src/impl/ops/io_remote.cpp View File

@@ -18,7 +18,7 @@
#include "megbrain/opr/mm_handler.h"
#endif // MGB_ENABLE_OPR_MM

#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {
@@ -60,45 +60,5 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
} // anonymous namespace
#endif // MGB_ENABLE_OPR_MM

bool RemoteSend::is_same_st(const Hashable& another) const{
return as_tuple() == another.cast_final<RemoteSend>().as_tuple();
}

size_t RemoteSend::hash() const{
XXHash xxhash;
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(addr);
append(port);
append(rank_to);
return xxhash.digest();
}

bool RemoteRecv::is_same_st(const Hashable& another) const{
return as_tuple() == another.cast_final<RemoteRecv>().as_tuple();
}

size_t RemoteRecv::hash() const{
XXHash xxhash;
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(addr);
append(port);
append(rank_from);
append(cn.to_string());
append(dtype.handle());
append(shape.to_string());
return xxhash.digest();
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);

} // namespace imperative
} // namespace mgb

+ 1
- 3
imperative/src/impl/ops/nms.cpp View File

@@ -11,7 +11,7 @@

#include "../op_trait.h"

#include "megbrain/imperative/ops/nms.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/standalone/nms_opr.h"

namespace mgb {
@@ -37,8 +37,6 @@ OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr)
.fallback();
} // anonymous namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep);

} // namespace imperative
} // namespace mgb



+ 630
- 0
imperative/src/impl/ops/specializations.cpp View File

@@ -0,0 +1,630 @@
/**
* \file imperative/src/impl/ops/autogen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../op_trait.h"

namespace mgb::imperative {

namespace { namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Convolution>();
return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution&>(def);
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution

namespace { namespace convolution_backward_data {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
cg::OperatorNodeConfig config;
if (inputs.size() == 2) {
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else {
mgb_assert(inputs.size() == 3);
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
}
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution_backward_data

namespace { namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
std::vector<int> pattern(node->param().pattern_len);
for (size_t i = 0; i < node->param().pattern_len; ++ i) {
pattern[i] = node->param().pattern[i];
}
return Dimshuffle::make(pattern);
}

auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& ds = static_cast<const Dimshuffle&>(def);
return opr::Dimshuffle::make(inputs[0], ds.pattern);
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // dimshuffle

namespace { namespace add_axis {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& add_axis = static_cast<const AddAxis&>(def);
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param;
for (auto&& i : add_axis.axis) {
param.push_back(Desc::make_add(i));
}
return opr::AxisAddRemove::make(inputs[0], param);
}

OP_TRAIT_REG(AddAxis, AddAxis)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // add_axis

namespace { namespace remove_axis {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& remove_axis = static_cast<const RemoveAxis&>(def);
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param;
for (auto&& i : remove_axis.axis) {
param.push_back(Desc::make_remove(i));
}
return opr::AxisAddRemove::make(inputs[0], param);
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // remove_axis

namespace { namespace top_k {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& topk = static_cast<const TopK&>(def);
return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
.node()->owner_opr();
}

OP_TRAIT_REG(TopK, TopK)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // top_k

namespace { namespace reduce {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
if (inputs.size() > 1) {
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
} else {
return opr::Reduce::make(inputs[0], reduce.param());
}
}

OP_TRAIT_REG(Reduce, Reduce)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // reduce

namespace { namespace adaptive_pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& pool = static_cast<const AdaptivePooling&>(def);
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // adaptive_pooling

namespace { namespace conv_bias {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvBias&>(def);
cg::OperatorNodeConfig config{conv.dtype};
if (inputs.size() == 2) {
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else if (inputs.size() == 3) {
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
} else if (inputs.size() == 4) {
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
}
mgb_assert(0);
}

OP_TRAIT_REG(ConvBias, ConvBias)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // conv_bias

namespace { namespace batch_conv_bias {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const BatchConvBias&>(def);
cg::OperatorNodeConfig config{conv.dtype};
if (inputs.size() == 2) {
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else if (inputs.size() == 3) {
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
} else if (inputs.size() == 4) {
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
}
mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // batch_conv_bias

namespace { namespace pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
return opr::Pooling::make(inputs[0], pool.param());
}
OP_TRAIT_REG(Pooling, Pooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // pooling

namespace { namespace matrix_mul {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2);
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param());
}
OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // matrix_mul

namespace { namespace batched_matrix_mul {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2);
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param());
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // batched_matrix_mul

namespace { namespace dot {
auto apply_on_var_node(
const OpDef&,
const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 2);
return opr::Dot::make(inputs[0], inputs[1]);
}
OP_TRAIT_REG(Dot, Dot)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // dot

namespace { namespace argsort {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argsort = static_cast<const Argsort&>(def);
return opr::Argsort::make(inputs[0], argsort.param());
}
OP_TRAIT_REG(Argsort, Argsort)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argsort

namespace { namespace argmax {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argmax = static_cast<const Argmax&>(def);
return opr::Argmax::make(inputs[0], argmax.param());
}
OP_TRAIT_REG(Argmax, Argmax)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argmax

namespace { namespace argmin {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argmin = static_cast<const Argmin&>(def);
return opr::Argmin::make(inputs[0], argmin.param());
}
OP_TRAIT_REG(Argmin, Argmin)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argmin

namespace { namespace warp_perspective {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& warp = static_cast<const WarpPerspective&>(def);
if (inputs.size() == 3) {
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param());
} else {
mgb_assert(inputs.size() == 4);
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param());
}
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // warp_perspective

namespace { namespace group_local {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& local = static_cast<const GroupLocal&>(def);
mgb_assert(inputs.size() == 2);
return opr::GroupLocal::make(inputs[0], inputs[1], local.param());
}
OP_TRAIT_REG(GroupLocal, GroupLocal)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // group_local

namespace { namespace indexing_one_hot {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
mgb_assert(inputs.size() == 2);
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param());
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // indexing_one_hot

namespace { namespace indexing_set_one_hot {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param());
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // indexing_set_one_hot

namespace { namespace typecvt {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const TypeCvt&>(def);
mgb_assert(inputs.size() == 1);
return opr::TypeCvt::make(inputs[0], op.dtype);
}
OP_TRAIT_REG(TypeCvt, TypeCvt)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // typecvt

namespace { namespace concat {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Concat&>(def);
cg::OperatorNodeConfig config{op.comp_node};
return opr::Concat::make(inputs, op.axis, config);
}
OP_TRAIT_REG(Concat, Concat)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // concat

namespace { namespace copy {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Copy&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
return opr::Copy::make(inputs[0], config);
}
OP_TRAIT_REG(Copy, Copy)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // copy

namespace { namespace identity {
auto apply_on_var_node(
const OpDef&,
const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 1);
return opr::Identity::make(inputs[0]);
}
OP_TRAIT_REG(Identity, Identity)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // identity

namespace { namespace uniform_rng {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const UniformRNG&>(def);
mgb_assert(inputs.size() == 1);
return opr::UniformRNG::make(inputs[0], op.param());
}
OP_TRAIT_REG(UniformRNG, UniformRNG)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // uniform_rng

namespace { namespace gaussian_rng {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const GaussianRNG&>(def);
mgb_assert(inputs.size() == 1);
return opr::GaussianRNG::make(inputs[0], op.param());
}
OP_TRAIT_REG(GaussianRNG, GaussianRNG)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // gaussian_rng

namespace { namespace roi_align {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIAlign&>(def);
mgb_assert(inputs.size() == 2);
return opr::ROIAlign::make(inputs[0], inputs[1], op.param());
}
OP_TRAIT_REG(ROIAlign, ROIAlign)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // roi_align

#if MGB_CUDA
namespace { namespace nvof {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const NvOf&>(def);
mgb_assert(inputs.size() == 1);
return opr::NvOf::make(inputs[0], op.param());
}
OP_TRAIT_REG(NvOf, NvOf)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // nvof
#endif

namespace { namespace linspace {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Linspace&>(def);
mgb_assert(inputs.size() == 3);
cg::OperatorNodeConfig config{op.comp_node};
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(Linspace, Linspace)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // linspace

namespace { namespace eye {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Eye&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
opr::Eye::Param param{op.k, op.dtype.enumv()};
return opr::Eye::make(inputs[0], param, config);
}
OP_TRAIT_REG(Eye, Eye)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // eye

namespace { namespace roi_pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIPooling&>(def);
mgb_assert(inputs.size() == 3);
return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param());
}
OP_TRAIT_REG(ROIPooling, ROIPooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // roi_pooling

namespace { namespace remap {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Remap&>(def);
mgb_assert(inputs.size() == 2);
return opr::Remap::make(inputs[0], inputs[1], op.param());
}
OP_TRAIT_REG(Remap, Remap)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // remap

namespace { namespace reshape {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Reshape&>(def);
mgb_assert(inputs.size() == 2);
return opr::Reshape::make(inputs[0], inputs[1], op.param());
}
OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // reshape

namespace {
auto get_index(
const VarNodeArray& inputs, size_t vidx,
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
size_t length = mask.size();
opr::Subtensor::IndexDesc ret(length);
for (size_t i = 0; i < length; ++ i) {
auto&& [axis, begin, end, step, idx] = mask[i];
ret[i].axis = axis;
if (idx) {
ret[i].idx = inputs[vidx++];
} else {
mgb_assert(begin || end || step);
if (begin) ret[i].begin = inputs[vidx++];
if (end) ret[i].end = inputs[vidx++];
if (step) ret[i].step = inputs[vidx++];
}
}
mgb_assert(vidx == inputs.size());
return ret;
}
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
const OpDef& def, \
const VarNodeArray& inputs) { \
auto&& op = static_cast<const NAME&>(def); \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
} \
OP_TRAIT_REG(NAME, NAME) \
.apply_on_var_node(apply_on_var_node) \
.fallback(); \
}

FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace

namespace { namespace fake_quant {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const FakeQuant&>(def);
mgb_assert(inputs.size() == 3);
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param());
}
OP_TRAIT_REG(FakeQuant, FakeQuant)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // fake_quant
namespace { namespace elemwise_multi_type {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ElemwiseMultiType&>(def);
OperatorNodeConfig config{op.dtype};
return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // fake_quant

namespace { namespace svd {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const SVD&>(def);
mgb_assert(inputs.size() == 1);
return opr::SVD::make(inputs[0], op.param());
}
OP_TRAIT_REG(SVD, SVD)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // svd

} // namespace mgb::imperative

+ 1
- 5
imperative/src/impl/ops/tensor_manip.cpp View File

@@ -9,7 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/tensor_manip.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/tensor_manip.h"
#include "../op_trait.h"
@@ -140,8 +140,4 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.fallback();
} // namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat);

} // namespace mgb::imperative

+ 7
- 7
imperative/src/impl/profiler.cpp View File

@@ -130,7 +130,7 @@ void Profiler::start(uint32_t flags) {
// TODO: assign parent
entry.parent = 0;
// Record apply context and save to m_profile
entry.op = def.copy();
entry.op = const_cast<OpDef&>(def).shared_from_this();
for (auto&& input : inputs) {
entry.inputs.push_back({m_tensor_recorder.record_tensor(input),
shape2vector(input->layout()),
@@ -172,31 +172,31 @@ void Profiler::start(uint32_t flags) {
if (flags & PROFILE_FOOTPRINT) {
hook_apply_on_var_node->apply_hook(
[this](auto&& apply, const OpDef& def,
VarNodeArray inputs) -> cg::OperatorNodeBase* {
auto* operator_node = apply(def, std::move(inputs));
VarNodeArray inputs) -> VarNodeArray {
auto vars = apply(def, std::move(inputs));
std::remove_reference_t<decltype(m_entry_stack.top())>
top;
{
MGB_LOCK_GUARD(m_lock);
if (m_entry_stack.empty()) {
return operator_node;
return vars;
}
top = m_entry_stack.top();
}
auto [current_op, current_entry, thread_id] = top;
if (current_op != &def ||
thread_id != std::this_thread::get_id()) {
return operator_node;
return vars;
}
auto&& footprint_result =
footprint.calc_footprint(operator_node);
footprint.calc_footprint(vars[0]->owner_opr());
current_entry->memory = footprint_result.memory;
current_entry->computation =
footprint_result.computation;
#if MGB_ENABLE_JSON
current_entry->param = footprint_result.param;
#endif
return operator_node;
return vars;
});
}
m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor));


+ 3
- 3
imperative/src/impl/proxy_graph.cpp View File

@@ -590,7 +590,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
for (size_t i = 0; i < inputs.size(); ++ i) {
vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
}
auto opr = OpDef::apply_on_var_node(opdef, vinputs);
auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
mgb_assert(!opr->same_type<InputPlaceholder>());
for (auto &&i : opr->input()) {
mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
@@ -639,7 +639,7 @@ ProxyGraph::make_backward_graph(
return ret.first->second;
};
auto inputs = make_input_place_holders(input_descs);
auto fwd = OpDef::apply_on_var_node(opdef, inputs);
auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
auto&& outputs = fwd->usable_output();
SmallVector<LogicalTensorDesc> output_descs;
for (auto&& i : outputs) {
@@ -799,7 +799,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef,
const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(!m_cur_opr);
auto vinputs = make_input_place_holders(inputs);
return OpDef::apply_on_var_node(opdef, vinputs);
return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
}

VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) {


+ 7
- 16
imperative/src/include/megbrain/imperative/op_def.h View File

@@ -26,13 +26,12 @@ struct BackwardGraphResult {
std::vector<bool> input_has_grad;
};

class OpDef : public Hashable {
class OpDef : public Hashable,
public std::enable_shared_from_this<OpDef> {
mutable const OpTrait* m_trait = nullptr;
public:
virtual ~OpDef() = default;

virtual std::shared_ptr<OpDef> copy() const = 0;

static std::shared_ptr<OpDef> make_from_op_node(
cg::OperatorNodeBase* node);

@@ -40,7 +39,7 @@ public:
const OpDef& def,
const SmallVector<TensorPtr>& inputs);

static cg::OperatorNodeBase* apply_on_var_node(
static cg::VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs);

@@ -56,25 +55,17 @@ public:

const OpTrait* trait() const;

virtual size_t hash() const {
mgb_throw(MegBrainError, "not implemented");
}
virtual size_t hash() const;

virtual bool is_same_st(const Hashable&) const {
mgb_throw(MegBrainError, "not implemented");
}
virtual bool is_same_st(const Hashable&) const;
};

template<typename T>
class OpDefImplBase : public OpDef {
public:
virtual std::shared_ptr<OpDef> copy() const override {
return std::shared_ptr<OpDef>(new T(this->cast_final_safe<T>()));
}

template<typename ...Args>
static std::shared_ptr<OpDef> make(const Args& ...args) {
return std::shared_ptr<OpDef>(new T(args...));
static std::shared_ptr<OpDef> make(Args&& ...args) {
return std::make_shared<T>(std::forward<Args>(args)...);
}
};



imperative/src/include/megbrain/imperative/ops/cond_take.h → imperative/src/include/megbrain/imperative/ops/autogen.h View File

@@ -1,5 +1,5 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/cond_take.h
* \file imperative/src/include/megbrain/imperative/ops/autogen.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -12,22 +12,15 @@
#pragma once

#include "megbrain/imperative/op_def.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"

namespace mgb::imperative {

class CondTake : public OpDefImplBase<CondTake> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
CondTake() = default;
#include "megbrain/utils/hash.h"

size_t hash() const override {
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
}

bool is_same_st(const Hashable& rhs) const override {
return rhs.dyn_typeinfo() == dyn_typeinfo();
}
namespace mgb::imperative {

};
// TODO: split into separate files to avoid recompiling all
// impl/ops/*.cpp on each modification of ops.td
#include "./opdef.h.inl"

} // namespace mgb::imperative

+ 0
- 70
imperative/src/include/megbrain/imperative/ops/batch_norm.h View File

@@ -1,70 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/batch_norm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"

namespace mgb::imperative {

class BatchNorm : public OpDefImplBase<BatchNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Param = opr::BatchNorm::Param;

Param::ParamDim param_dim;
Param::FwdMode fwd_mode;
double epsilon;
double avg_factor;
float scale;
float bias;

BatchNorm() = default;
BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_,
double epsilon_, double avg_factor_, float scale_, float bias_)
: param_dim(param_dim_),
fwd_mode(fwd_mode_),
epsilon(epsilon_),
avg_factor(avg_factor_),
scale(scale_),
bias(bias_) {}

size_t hash() const override {
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(param_dim);
append(fwd_mode);
append(epsilon);
append(avg_factor);
append(scale);
append(bias);
return xxhash.digest();
}

bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const BatchNorm&>(rhs_);
return rhs.param_dim == param_dim
&& rhs.fwd_mode == fwd_mode
&& rhs.epsilon == epsilon
&& rhs.avg_factor == avg_factor
&& rhs.scale == scale
&& rhs.bias == bias;
}

};

} // namespace mgb::imperative

+ 0
- 35
imperative/src/include/megbrain/imperative/ops/broadcast.h View File

@@ -1,35 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/broadcast.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/opr/tensor_manip.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/op_def.h"

namespace mgb::imperative {

class Broadcast : public OpDefImplBase<Broadcast> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
Broadcast() = default;

size_t hash() const override {
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
}

bool is_same_st(const Hashable& rhs) const override {
return true;
}

};

} // namespace mgb::imperative

+ 0
- 69
imperative/src/include/megbrain/imperative/ops/collective_comm.h View File

@@ -1,69 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/collective_comm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/op_def.h"
#include "megbrain/opr/param_defs.h"

namespace mgb {
namespace imperative {

class CollectiveComm : public OpDefImplBase<CollectiveComm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
using Mode = megdnn::param::CollectiveComm::Mode;

CollectiveComm() = default;
CollectiveComm(const std::string& key_, size_t nr_devices_,
uint32_t rank_, bool is_root_, bool local_grad_,
const std::string& addr_, uint32_t port_,
const Mode& mode_,
const DType& dtype_, const std::string& backend_,
const std::string& comp_node_)
: key(key_),
nr_devices(nr_devices_),
rank(rank_),
is_root(is_root_),
local_grad(local_grad_),
addr(addr_),
port(port_),
mode(mode_),
dtype(dtype_),
backend(backend_),
comp_node(comp_node_) {}
std::string key;
size_t nr_devices;
uint32_t rank;
bool is_root;
bool local_grad;
std::string addr;
uint32_t port;
Mode mode;
DType dtype;
std::string backend;
std::string comp_node;

size_t hash() const override;

bool is_same_st(const Hashable& another) const override;
auto as_tuple() const{
return std::tuple(key, nr_devices, rank, is_root,
local_grad, addr, port, mode, dtype,
backend, comp_node);
}
};

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 0
- 42
imperative/src/include/megbrain/imperative/ops/elemwise.h View File

@@ -1,42 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/elemwise.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/op_def.h"

namespace mgb::imperative {

class Elemwise : public OpDefImplBase<Elemwise> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = opr::Elemwise::Mode;
using ModeTrait = megdnn::Elemwise::ModeTrait;

Mode mode;

Elemwise() = default;
Elemwise(const Mode& mode_): mode(mode_) {}
size_t hash() const override {
return hash_pair_combine(mgb::hash(mode), reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}

bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const Elemwise&>(rhs_);
return rhs.mode == mode;
}

};

} // namespace mgb::imperative

+ 0
- 77
imperative/src/include/megbrain/imperative/ops/io_remote.h View File

@@ -1,77 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/io_remote.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb {
namespace imperative {

class RemoteSend : public OpDefImplBase<RemoteSend> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
RemoteSend() = default;
RemoteSend(const std::string& key_, const std::string& addr_,
uint32_t port_, uint32_t rank_to_)
: key(key_),
addr(addr_),
port(port_),
rank_to(rank_to_) {}
std::string key;
std::string addr;
uint32_t port;
uint32_t rank_to;

size_t hash() const override;
bool is_same_st(const Hashable& another) const override;

auto as_tuple() const{
return std::tuple(key, addr, port, rank_to);
}
};

class RemoteRecv : public OpDefImplBase<RemoteRecv> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
RemoteRecv() = default;
RemoteRecv(const std::string& key_, const std::string& addr_,
uint32_t port_, uint32_t rank_from_, TensorShape shape_,
CompNode cn_, const DType& dtype_)
: key(key_),
addr(addr_),
port(port_),
rank_from(rank_from_),
cn(cn_),
shape(shape_),
dtype(dtype_) {}
std::string key;
std::string addr;
uint32_t port;
uint32_t rank_from;
CompNode cn;
TensorShape shape;
DType dtype;

size_t hash() const override;
bool is_same_st(const Hashable& another) const override;

auto as_tuple() const{
return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string());
}
};

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 0
- 41
imperative/src/include/megbrain/imperative/ops/nms.h View File

@@ -1,41 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/nms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb::imperative {

class NMSKeep : public OpDefImplBase<NMSKeep> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float iou_thresh; //!< IoU threshold for overlapping
uint32_t max_output; //!< max number of output boxes per batch
NMSKeep() = default;
NMSKeep(float iou_thresh_, uint32_t max_output_):
iou_thresh(iou_thresh_), max_output(max_output_) {}

size_t hash() const override {
return hash_pair_combine(
hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)),
reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}

bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const NMSKeep&>(rhs_);
return rhs.iou_thresh == iou_thresh
&& rhs.max_output == max_output;
}

};

} // namespace mgb::imperative

+ 0
- 99
imperative/src/include/megbrain/imperative/ops/tensor_manip.h View File

@@ -1,99 +0,0 @@
/**
* \file imperative/src/include/megbrain/imperative/ops/tensor_manip.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/op_def.h"

#include "megbrain/utils/hash.h"

namespace mgb::imperative {

class GetVarShape : public OpDefImplBase<GetVarShape> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
GetVarShape() = default;

size_t hash() const override {
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
}

bool is_same_st(const Hashable& rhs) const override {
return rhs.dyn_typeinfo() == dyn_typeinfo();
}
};

class ParamPackSplit : public OpDefImplBase<ParamPackSplit> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
ParamPackSplit() = default;

ParamPackSplit(std::vector<dt_int32>& offsets_,
std::vector<std::vector<size_t>>& shapes_)
: offsets(offsets_), shapes(shapes_) {}

std::vector<dt_int32> offsets;
std::vector<std::vector<size_t>> shapes;

size_t hash() const override {
XXHash builder;
for (auto&& offset : offsets) {
builder.update(&offset, sizeof(offset));
}
auto&& offset_cnt = offsets.size();
builder.update(&offset_cnt, sizeof(offset_cnt));
for (auto&& shape : shapes) {
for (auto&& dim_len : shape) {
builder.update(&dim_len, sizeof(dim_len));
}
auto&& dim_cnt = shape.size();
builder.update(&dim_cnt, sizeof(dim_cnt));
}
auto&& shape_cnt = shapes.size();
builder.update(&shape_cnt, sizeof(shape_cnt));
return builder.digest();
}

bool is_same_st(const Hashable& rhs) const override {
auto&& pps = rhs.cast_final_safe<ParamPackSplit>();
return offsets == pps.offsets && shapes == pps.shapes;
}
};

class ParamPackConcat : public OpDefImplBase<ParamPackConcat> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
ParamPackConcat() = default;

ParamPackConcat(std::vector<dt_int32>& offsets_)
: offsets(offsets_) {}

std::vector<dt_int32> offsets;

size_t hash() const override {
XXHash builder;
for (auto&& offset : offsets) {
builder.update(&offset, sizeof(offset));
}
auto&& offset_cnt = offsets.size();
builder.update(&offset_cnt, sizeof(offset_cnt));
return builder.digest();
}

bool is_same_st(const Hashable& rhs) const override {
auto&& ppc = rhs.cast_final_safe<ParamPackConcat>();
return offsets == ppc.offsets;
}
};

} // namespace mgb::imperative

+ 8
- 8
imperative/src/test/backward_graph.cpp View File

@@ -29,18 +29,18 @@ TEST(TestImperative, BackwardGraphBasic) {

using Param = opr::Elemwise::Param;
Param param{Param::Mode::MUL};
OprAttr attr{"Elemwise", {}, {}};
attr.param.write_pod(param);
auto attr = OprAttr::make("Elemwise");
attr->cast_final_safe<OprAttr>().param.write_pod(param);

SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node()});
}
auto result = OpDef::make_backward_graph(attr, input_descs, {true, true}, {true});
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true});
auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad;

auto outputs = OpDef::apply_on_physical_tensor(attr, inputs);
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
inputs.push_back(outputs[0]);
hvs.push_back(*gen({42}));
inputs.push_back(Tensor::make(hvs.back()));
@@ -82,16 +82,16 @@ TEST(TestImperative, BackwardGraphIdentity) {
SmallVector<TensorPtr> inputs;
inputs.push_back(a);

OprAttr attr{"Identity", {}, {}};
attr.param.write_pod<megdnn::param::Empty>({});
auto attr = OprAttr::make("Identity");
attr->cast_final_safe<OprAttr>().param.write_pod<megdnn::param::Empty>({});

SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back({a->layout(), a->comp_node()});
auto result = OpDef::make_backward_graph(attr, input_descs, {true}, {true});
auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad;

auto outputs = OpDef::apply_on_physical_tensor(attr, inputs);
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
inputs.push_back(outputs[0]);
inputs.push_back(dc);
mgb_assert(save_for_backward.size() == inputs.size());


+ 6
- 5
imperative/src/test/collective_comm.cpp View File

@@ -10,7 +10,7 @@
*/

#include "./helper.h"
#include "megbrain/imperative/ops/collective_comm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h"

using namespace mgb;
@@ -32,12 +32,13 @@ TEST(TestImperative, AllReduceBasic) {
}

auto run = [&](std::shared_ptr<HostTensorND> hnd, uint32_t idx) {
imperative::CollectiveComm
def{"all_reduce", 2, idx, idx==0, false, server_addr, port,
auto def =
imperative::CollectiveComm::make(
megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM,
dtype::Float32(), "nccl", ""};
"all_reduce", 2, idx, idx==0, false, server_addr, port,
dtype::Float32(), "nccl", "");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(def, {inp});
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);


+ 1
- 1
imperative/src/test/cond_take.cpp View File

@@ -10,7 +10,7 @@
*/

#include "./helper.h"
#include "megbrain/imperative/ops/cond_take.h"
#include "megbrain/imperative/ops/autogen.h"

using namespace mgb;
using namespace imperative;


+ 1
- 1
imperative/src/test/helper.cpp View File

@@ -119,7 +119,7 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) {
}, inp_keys[i]);
sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
}
auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp)->usable_output();
auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp);
size_t nr_oups = sym_oup.size();
ComputingGraph::OutputSpec oup_spec(nr_oups);
SmallVector<HostTensorND> host_sym_oup(nr_oups);


+ 9
- 14
imperative/src/test/io_remote.cpp View File

@@ -10,7 +10,7 @@
*/

#include "./helper.h"
#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h"

using namespace mgb;
@@ -33,24 +33,19 @@ TEST(TestImperative, IORemote) {
}

auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
imperative::RemoteSend def{"io_remote_test", server_addr, port, 1};
auto def = imperative::RemoteSend::make(
"io_remote_test", server_addr, port, 1);
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(def, {inp});
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
};

auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) {
// auto&& shape = std::initializer_list{vector_size};
imperative::RemoteRecv def{"io_remote_test",
server_addr,
port,
0,
{
vector_size,
},
CompNode::load("gpu1"),
dtype::Float32()};
auto def = imperative::RemoteRecv::make(
"io_remote_test", server_addr, port, 0,
CompNode::load("gpu1"), TensorShape{vector_size},
dtype::Float32());
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(def, {inp});
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);


+ 14
- 0
imperative/tablegen/CMakeLists.txt View File

@@ -0,0 +1,14 @@
# mgb tablegen executable
set(TABLE_TARGET mgb-mlir-autogen)
add_executable(${TABLE_TARGET} autogen.cpp)
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})

# generate megbrain opdef c header and python bindings
set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td)
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header")
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body")
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding")
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen)
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)

+ 383
- 0
imperative/tablegen/autogen.cpp View File

@@ -0,0 +1,383 @@
#include <iostream>
#include <unordered_map>
#include <functional>

#include "./helper.h"

using llvm::raw_ostream;
using llvm::RecordKeeper;

enum ActionType {
None,
CppHeader,
CppBody,
Pybind
};

// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
"Generate operator cpp header"),
clEnumValN(CppBody, "gen-cpp-body",
"Generate operator cpp body"),
clEnumValN(Pybind, "gen-python-binding",
"Generate pybind11 python bindings")));

using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;

llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
// Note: we have already registered the corresponding attr wrappers
// for following basic ctypes so we needn't handle them here
/* auto&& attr_type_name = attr.getAttrDefName();
if (attr_type_name == "UI32Attr") {
return "uint32_t";
}
if (attr_type_name == "UI64Attr") {
return "uint64_t";
}
if (attr_type_name == "I32Attr") {
return "int32_t";
}
if (attr_type_name == "F32Attr") {
return "float";
}
if (attr_type_name == "F64Attr") {
return "double";
}
if (attr_type_name == "StrAttr") {
return "std::string";
}
if (attr_type_name == "BoolAttr") {
return "bool";
}*/

auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
return e->getEnumName();
}
return attr.getUnderlyingType();
}

static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
os << formatv(
"class {0} : public OpDefImplBase<{0}> {{\n"
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
"public:\n",
op.getCppClassName()
);
// handle enum alias
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
os << formatv(
" using {0} = {1};\n",
attr->getEnumName(), attr->getUnderlyingType()
);
}
}
for (auto &&i : op.getMgbAttributes()) {
auto defaultValue = i.attr.getDefaultValue().str();
if (!defaultValue.empty()) {
defaultValue = formatv(" = {0}", defaultValue);
}
os << formatv(
" {0} {1}{2};\n",
attr_to_ctype(i.attr), i.name, defaultValue
);
}

auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
os << formatv(
" {0}({1}){2}{3}\n",
op.getCppClassName(), paramList, memInitList, body
);
};

gen_ctor("", "", " = default;");

if (!op.getMgbAttributes().empty()) {
std::vector<std::string> paramList, initList;
for (auto &&i : op.getMgbAttributes()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "),
" {}");
}

auto packedParams = op.getPackedParams();
if (!packedParams.empty()) {
std::vector<std::string> paramList, initList;
for (auto &&p : packedParams) {
auto&& paramFields = p.getFields();
auto&& paramType = p.getFullName();
auto&& paramName = formatv("packed_param_{0}", paramList.size());
paramList.push_back(
paramFields.empty() ? paramType.str()
: formatv("{0} {1}", paramType, paramName)
);
for (auto&& i : paramFields) {
initList.push_back(formatv(
"{0}({1}.{0})", i.name, paramName
));
}
}
for (auto&& i : op.getExtraArguments()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
gen_ctor(llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "),
" {}");
}

if (!packedParams.empty()) {
for (auto&& p : packedParams) {
auto accessor = p.getAccessor();
if (!accessor.empty()) {
os << formatv(
" {0} {1}() const {{\n",
p.getFullName(), accessor
);
std::vector<llvm::StringRef> fields;
for (auto&& i : p.getFields()) {
fields.push_back(i.name);
}
os << formatv(
" return {{{0}};\n",
llvm::join(fields, ", ")
);
os << " }\n";
}
}
}

if (auto decl = op.getExtraOpdefDecl()) {
os << decl.getValue();
}

os << formatv(
"};\n\n"
);
}

static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
auto&& className = op.getCppClassName();
os << formatv(
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
);
auto formatMethImpl = [&](auto&& meth) {
return formatv(
"{0}_{1}_impl", className, meth
);
};
std::vector<std::string> methods;
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
os << "namespace {\n";

// generate hash()
mlir::tblgen::FmtContext ctx;
os << formatv(
"size_t {0}(const OpDef& def_) {{\n",
formatMethImpl("hash")
);
os << formatv(
" auto op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
os << "}\n";

// generate is_same_st()
os << formatv(
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
formatMethImpl("is_same_st")
);
os << formatv(
" auto a_ = lhs_.cast_final_safe<{0}>(),\n"
" b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
);
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";

os << "} // anonymous namespace\n";

methods.push_back("hash");
methods.push_back("is_same_st");
}
if (!methods.empty()) {
os << formatv(
"OP_TRAIT_REG({0}, {0})", op.getCppClassName()
);
for (auto&& i : methods) {
os << formatv(
"\n .{0}({1})", i, formatMethImpl(i)
);
}
os << ";\n\n";
}
}

struct PybindContext {
std::unordered_map<unsigned int, std::string> enumAlias;
};

static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) {
auto class_name = op.getCppClassName();
os << formatv(
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
class_name
);
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
if (iter == enumAlias.end()) {
os << formatv(
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
class_name, attr->getEnumName()
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
class_name, attr->getEnumName(), i
);
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
class_name, attr->getEnumName(), i
));
}
os << formatv(
"\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);"
"\n {0}"
"\n throw py::cast_error(\"invalid enum value \" + in);"
"\n }));\n",
llvm::join(body, "\n ")
);
os << formatv(
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
class_name, attr->getEnumName()
);
enumAlias.emplace(enumID, formatv(
"{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName()
));
} else {
os << formatv(
"{0}Inst.attr(\"{1}\") = {2};\n\n",
class_name, attr->getEnumName(), iter->second
);
}
}
}
// generate op class binding
os << formatv("{0}Inst", class_name);
bool hasDefaultCtor = op.getMgbAttributes().empty();
if (!hasDefaultCtor) {
os << "\n .def(py::init<";
std::vector<llvm::StringRef> targs;
for (auto &&i : op.getMgbAttributes()) {
targs.push_back(i.attr.getReturnType());
}
os << llvm::join(targs, ", ");
os << ">()";
for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue();
if (!defaultValue.empty()) {
os << formatv(" = {0}", defaultValue);
} else {
hasDefaultCtor = true;
}
}
os << ")";
}
if (hasDefaultCtor) {
os << "\n .def(py::init<>())";
}
for (auto &&i : op.getMgbAttributes()) {
os << formatv(
"\n .def_readwrite(\"{0}\", &{1}::{0})",
i.name, class_name
);
}
os << ";\n\n";
}

static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
std::function<void(raw_ostream&, MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
ASSERT(op_base_class, "could not find base class Op");
for (auto&& i: keeper.getDefs()) {
auto&& r = i.second;
if (r->isSubClassOf(op_base_class)) {
auto op = mlir::tblgen::Operator(r.get());
if (op.getDialectName().str() == "mgb") {
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
callback(os, llvm::cast<MgbOp>(op));
}
}
}
}

static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_header_single);
return false;
}

static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_body_single);
return false;
}

static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
PybindContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
return false;
}

int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
if (action == ActionType::CppHeader) {
return TableGenMain(argv[0], &gen_op_def_c_header);
}
if (action == ActionType::CppBody) {
return TableGenMain(argv[0], &gen_op_def_c_body);
}
if (action == ActionType::Pybind) {
return TableGenMain(argv[0], &gen_op_def_pybind11);
}
return -1;
}

+ 228
- 0
imperative/tablegen/helper.h View File

@@ -0,0 +1,228 @@
#include <string>
#include <vector>

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"

using llvm::formatv;
using llvm::StringRef;
using llvm::Record;

#define ASSERT(stmt, msg) \
if (!(stmt)) { \
std::cerr << "\033[1;31m" \
<< "tablegen autogen abort due to: " << msg \
<< "\033[0m" << std::endl; \
exit(1); \
}

namespace mlir {
namespace tblgen {
template<typename ConcreteType>
struct MgbInterface : public ConcreteType {
MgbInterface() = delete;
MgbInterface(const MgbInterface&) = delete;
MgbInterface(MgbInterface&&) = delete;
~MgbInterface() = delete;
};

struct MgbAttrWrapperBase : public MgbInterface<Attribute> {
private:
struct RecordVisitor : public MgbInterface<Constraint> {
public:
static bool classof(const Constraint*) {
return true;
}

const llvm::Record* getDef() const {
return def;
}
};
public:
static bool classof(const Attribute* attr) {
return attr->isSubClassOf("MgbAttrWrapperBase");
}

const llvm::Record* getBaseRecord() const {
auto baseAttr = getBaseAttr();
return llvm::cast<RecordVisitor>(baseAttr).getDef();
}
llvm::StringRef getUnderlyingType() const {
return def->getValueAsString("underlyingType");
}
};

struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
static bool classof(const Attribute* attr) {
return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin");
}

llvm::StringRef getParentNamespace() const {
return getBaseRecord()->getValueAsString("parentNamespce");
}
llvm::StringRef getEnumName() const {
return getBaseRecord()->getValueAsString("enumName");
}
std::vector<StringRef> getEnumMembers() const {
return getBaseRecord()->getValueAsListOfStrings("enumMembers");
}
};

struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
static bool classof(const Attribute* attr) {
return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin");
}

llvm::StringRef getHashFunctionTemplate() const {
return getBaseRecord()->getValueAsString("hashFunction");
}
llvm::StringRef getCmpFunctionTemplate() const {
return getBaseRecord()->getValueAsString("cmpFunction");
}
};

struct MgbAliasAttrMixin : public MgbAttrWrapperBase {
static bool classof(const Attribute* attr) {
return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin");
}

Attribute getAliasBase() const {
return Attribute(getBaseRecord()->getValueAsDef("aliasBase"));
}
};

class MgbPackedParam {
public:
MgbPackedParam(Record* def_): def(def_) {
auto&& dag = def->getValueAsDag("fields");
for (size_t i = 0; i < dag->getNumArgs(); ++ i) {
fields.push_back({
dag->getArgNameStr(i),
Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))
});
}
}

llvm::StringRef getFullName() const {
return def->getValueAsString("fullName");
}
std::vector<NamedAttribute> getFields() const {
return fields;
}
llvm::StringRef getAccessor() const {
return def->getValueAsString("paramAccessor");
}
private:
std::vector<NamedAttribute> fields;
Record* def;
};

struct MgbOpBase : public MgbInterface<Operator> {
static bool isPackedParam(Record* def) {
return def->isSubClassOf("MgbPackedParamBase");
}

public:
static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbOp");
}

std::vector<NamedAttribute> getMgbAttributes() const {
std::vector<NamedAttribute> ret;
for (auto&& i: getAttributes()) {
if (isa<MgbAttrWrapperBase>(i.attr)) {
ret.push_back(i);
}
}
return ret;
}
std::vector<NamedAttribute> getExtraArguments() const {
std::vector<NamedAttribute> ret;
auto&& dag = getDef().getValueAsDag("extraArguments");
for (size_t i = 0; i < dag->getNumArgs(); ++ i) {
ret.push_back({
dag->getArgNameStr(i),
Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i)))
});
}
return ret;
}
llvm::Optional<StringRef> getExtraOpdefDecl() const {
return getDef().getValueAsOptionalString("extraOpdefDecl");
}
std::vector<MgbPackedParam> getPackedParams() const {
std::vector<MgbPackedParam> ret;
for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) {
if (isPackedParam(i)) {
ret.emplace_back(i);
}
}
return ret;
}
};

struct MgbHashableOpMixin : public MgbOpBase {
private:
std::string getDefaultHashFunction() const {
std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n";
if (!getMgbAttributes().empty()) {
auto getHashFunc = [&](auto&& iter) {
auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr);
return attr.getHashFunctionTemplate();
};
mlir::tblgen::FmtContext ctx;
for (auto&& it: getMgbAttributes()) {
body += formatv(
" val = mgb::hash_pair_combine(val, {0});\n",
mlir::tblgen::tgfmt(getHashFunc(it), &ctx, "$_self." + it.name)
);
}
}
body += " return val;\n";
return body;
}
std::string getDefaultCmpFunction() const {
std::string body;
if (!getMgbAttributes().empty()) {
mlir::tblgen::FmtContext ctx;
for (auto&& it : getMgbAttributes()) {
auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
body += formatv(
" if ({0}) return false;\n",
mlir::tblgen::tgfmt(attr.getCmpFunctionTemplate(),
&ctx, "$0." + it.name, "$1." + it.name)
);
}
}
body += " return true;\n";
return body;
}
public:
static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbHashableOpMixin");
}

std::string getHashFunctionTemplate() const {
if (auto f = getDef().getValueAsOptionalString("hashFunction")) {
return f.getValue().str();
}
return getDefaultHashFunction();
}
std::string getCmpFunctionTemplate() const {
if (auto f = getDef().getValueAsOptionalString("cmpFunction")) {
return f.getValue().str();
}
return getDefaultCmpFunction();
}
};

} // namespace tblgen
} // namespace mlir

+ 1
- 1
imperative/test/CMakeLists.txt View File

@@ -11,7 +11,7 @@ endif()

# TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS})
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR})

# Python binding
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})


+ 257
- 0
src/core/include/megbrain/ir/base.td View File

@@ -0,0 +1,257 @@
/**
* \file src/core/include/megbrain/ir/base.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#ifndef MGB_BASE
#define MGB_BASE

include "mlir/IR/OpBase.td"

def Mgb_Dialect : Dialect {
let name = "mgb";
let cppNamespace = "mgb::dialect";
}

// -- mgb Attr mixin
class MgbAttrWrapperBase<string className> {
string underlyingType = className;
int recursionDepth = 0;
}

class MgbHashableAttrMixin {
string hashFunction = "mgb::hash($0)";
// return 0 for eq, else for ne
string cmpFunction = "$0 != $1";
}

class MgbEnumAttrMixin<string namespace, string name, list<string> members> {
string parentNamespace = namespace;
string enumName = name;
list<string> enumMembers = members;
}

class MgbAttrWrapper;
class MgbAliasAttrMixin<Attr base> {
Attr aliasBase = base;
}

// -- mgb custom Attr
// TODO: CPred and description
class MgbAttrWrapper<string className>:
Attr<CPred<"true">, "TODO">, MgbAttrWrapperBase<className> {
let returnType = underlyingType;
}

class HashableAttr<string className>:
MgbAttrWrapper<className>, MgbHashableAttrMixin;

// -- basic types
class MgbIntegerAttrBase<string CType> : HashableAttr<CType> {
let storageType = "::mlir::IntegerAttr";
}

class MgbSignlessIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4), $0)";
}

class MgbSignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getSInt())";
let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4, true), $0)";
}

class MgbUnsignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getUInt())";
let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4, false), $0)";
}

def MgbI8Attr: MgbSignlessIntegerAttrBase<"int8_t">;
def MgbI32Attr: MgbSignlessIntegerAttrBase<"int32_t">;
def MgbI64Attr: MgbSignlessIntegerAttrBase<"int64_t">;
def MgbUI32Attr: MgbUnsignedIntegerAttrBase<"uint32_t">;
def MgbUI64Attr: MgbUnsignedIntegerAttrBase<"uint64_t">;
def MgbSizeTAddr: MgbUnsignedIntegerAttrBase<"size_t">;

class MgbFloatAttrBase<string CType, string DType> : HashableAttr<CType> {
let storageType = "::mlir::FloatAttr";
let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getValueAsDouble())";
let constBuilderCall = "$_builder.getFloatAttr($_builder.get" # DType # "Type(), $0)";
}

def MgbF32Attr : MgbFloatAttrBase<"float", "F32">;
def MgbF64Attr : MgbFloatAttrBase<"double", "F64">;

def MgbBoolAttr : HashableAttr<"bool"> {
let storageType = "::mlir::BoolAttr";
let constBuilderCall = "$_builder.getBoolAttr($0)";
}

def MgbStringAttr : HashableAttr<"std::string"> {
let storageType = "::mlir::StringAttr";
let convertFromStorage = "$_self.getValue().str()";
let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor
}

class MgbArrayAttr<MgbAttrWrapper elem>:
HashableAttr<"std::vector<" # elem.underlyingType # ">"> {
let storageType = "::mlir::ArrayAttr";
let recursionDepth = !add(elem.recursionDepth, 1);
let convertFromStorage =
"[&] {\n"
" " # underlyingType # " ret" # recursionDepth # ";\n"
" std::for_each($_self.begin(), $_self.end(), [&](auto&& i" # recursionDepth # ") {\n"
" ret" # recursionDepth # ".push_back(\n"
" " # !subst("$_self", "i" # recursionDepth # ".template cast<" # elem.storageType # ">()", "" # elem.convertFromStorage) # "\n"
" );\n"
" });\n"
" return ret" # recursionDepth # ";}()";
let constBuilderCall =
"[&] {\n"
" std::vector<mlir::Attribute> ret" # recursionDepth # ";\n"
" std::for_each($0.begin(), $0.end(), [&](auto&& i" # recursionDepth # ") {\n"
" ret" # recursionDepth # ".push_back(\n"
" " # !subst("$0", "i" # recursionDepth, "" # elem.constBuilderCall) # "\n"
" );\n"
" });\n"
" return $_builder.getArrayAttr(ret" # recursionDepth # ");"
"}()";
}

defvar EmptyStrList = !listsplat("", 0);
class StrListAppend<list<string> l, string s> {
list<string> r = !listconcat(l, !listsplat(s, 1));
}

class TupleConvertFromStorage<MgbAttrWrapper attr, int idx> {
string r = !subst(
"$_self",
"$_self[" # !cast<string>(idx) # "].template cast<"# attr.storageType #">()",
"" # attr.convertFromStorage);
}

class TupleConstBuilderCall<MgbAttrWrapper attr, int idx> {
string r = !subst(
"$0",
"std::get<" # !cast<string>(idx) # ">($0)",
"" # attr.constBuilderCall);
}

class ApplyTupleConvertFromStorage<list<MgbAttrWrapper> args> {
list<string> r = !foldl(
EmptyStrList, args, l, arg, StrListAppend<l, TupleConvertFromStorage<arg, !size(l)>.r>.r);
}

class ApplyTupleConstBuilderCall<list<MgbAttrWrapper> args> {
list<string> r = !foldl(
EmptyStrList, args, l, arg, StrListAppend<l, TupleConstBuilderCall<arg, !size(l)>.r>.r);
}

class MgbTupleAttr<list<MgbAttrWrapper> args>:
HashableAttr<"std::tuple<" # StrJoin<!foreach(i, args, i.underlyingType)>.result # ">"> {
let storageType = "::mlir::ArrayAttr";
let convertFromStorage = "std::make_tuple(" # StrJoin<ApplyTupleConvertFromStorage<args>.r>.result # ")";
let constBuilderCall = "$_builder.getArrayAttr({" # StrJoin<ApplyTupleConstBuilderCall<args>.r>.result # "})";
}

// -- enum types
class MgbEnumAttr<string namespace, string enumName, list<string> members>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> {
let storageType = "::mlir::IntegerAttr";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
let hashFunction = "mgb::enumhash()($0)";
}

class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>:
MgbEnumAttr<namespace, enumName, base.enumMembers>, MgbAliasAttrMixin<base>;

// -- other types
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> {
let storageType = "::mlir::IntegerAttr";
let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))";
let hashFunction = "mgb::hash($0.handle())";
}

def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> {
let storageType = "::mlir::StringAttr";
let convertFromStorage = underlyingType # "::load($_self.getValue().str())";
let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())";
}

def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> {
let storageType = "::mlir::ArrayAttr";
let hashFunction = "mgb::PODHash<size_t>::perform($0.shape, $0.ndim)";
let cmpFunction = "!$0.eq_shape($1)";
defvar elemInst = MgbSizeTAddr;
let convertFromStorage =
"[&] {\n"
" " # underlyingType # " ret;\n"
" std::for_each($_self.begin(), $_self.end(), [&ret](auto&& i) {\n"
" ret[ret.ndim ++] = " # !subst("$_self", "i.template cast<"# elemInst.storageType #">()", "" # elemInst.convertFromStorage) # ";\n"
" });\n"
" return ret;}()";
let constBuilderCall =
"[&] {\n"
" std::vector<mlir::Attribute> ret;\n"
" for (size_t i = 0; i < $0.ndim; ++ i) {\n"
" ret.push_back(\n"
" " # !subst("$0", "$0[i]", "" # elemInst.constBuilderCall) # "\n"
" );\n"
" }\n"
" return $_builder.getArrayAttr(ret);"
"}()";
}

class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>:
DefaultValuedAttr<attr, value>, MgbAttrWrapperBase<attr.underlyingType> {
// Note: this class is similar to DefaultValuedAttr but with extra
// meta informations which are used by mgb dialect tblgen, so this
// has to be kept up to date with class MgbAttrWrapperMixin
let recursionDepth = attr.recursionDepth;
}

// -- dnn params
class MgbParamBase<string className> {
string paramType = className;
string fullName = "::megdnn::param::" # paramType;
dag fields = ?;
}

class MgbPackedParamBase<string className, string accessor>:
MgbParamBase<className> {
string paramAccessor = accessor;
}

// -- mgb ops
class MgbHashableOpMixin {
string hashFunction = ?;
string cmpFunction = ?;
}

class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
Op<Mgb_Dialect, mnemonic, traits> {
dag inputs = (ins);
dag extraArguments = (ins);
// TODO: remove it
code extraOpdefDecl = ?;

let arguments = !con(
!foldl(inputs, params, args, param, !con(args, param.fields)),
extraArguments);

list<MgbParamBase> dnnParams = params;
}

class MgbHashableOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
MgbOp<mnemonic, params, traits>, MgbHashableOpMixin;

#endif // MGB_BASE

+ 240
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -0,0 +1,240 @@
/**
* \file src/core/include/megbrain/ir/ops.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#ifndef MGB_OPS
#define MGB_OPS

include "base.td"
include "param_defs.td"

include "mlir/Interfaces/SideEffectInterfaces.td"

def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType);
}

def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;

def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let inputs = (ins AnyType:$inputs);
let extraArguments = (ins
TypeAttr:$idtype,
MgbDTypeAttr:$dtype
);
let results = (outs AnyType);
}

def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>;

def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>;

def Dot: MgbHashableOp<"Dot", [EmptyParam]>;

def SVD: MgbHashableOp<"SVD", [SVDParam]>;

def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;

def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>;

def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;

def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;

def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
}

def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
}

def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;

def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;

def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;

def Remap: MgbHashableOp<"Remap", [RemapParam]>;

def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;

def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;

def Copy: MgbHashableOp<"Copy"> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}

def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>;

def Argmax : MgbHashableOp<"Argmax", [AxisParam]>;

def Argmin : MgbHashableOp<"Argmin", [AxisParam]>;

def CondTake : MgbHashableOp<"CondTake">;

def TopK: MgbHashableOp<"TopK", [TopKParam]>;

def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;

def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
let cmpFunction = [{return true;}];
}

def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
}];
let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
}

def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}

def Eye: MgbHashableOp<"Eye", [EyeParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}

def GetVarShape : MgbHashableOp<"GetVarShape">;

def Concat: MgbHashableOp<"Concat", [AxisParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}

def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;

def Identity: MgbHashableOp<"Identity">;

def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
let extraArguments = (ins
MgbStringAttr:$key,
MgbUI32Attr:$nr_devices,
MgbUI32Attr:$rank,
MgbBoolAttr:$is_root,
MgbBoolAttr:$local_grad,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbDTypeAttr:$dtype,
MgbStringAttr:$backend,
MgbStringAttr:$comp_node
);
}

def RemoteSend : MgbHashableOp<"RemoteSend"> {
let extraArguments = (ins
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbUI32Attr:$rank_to
);
}

def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
let extraArguments = (ins
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape,
MgbDTypeAttr:$dtype
);
}

def NMSKeep : MgbHashableOp<"NMSKeep"> {
let extraArguments = (ins
MgbF32Attr:$iou_thresh,
MgbUI32Attr:$max_output
);
}

def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$offsets,
MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes
);
}

def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$offsets
);
}

def Dimshuffle: MgbHashableOp<"Dimshuffle"> {
let inputs = (ins AnyMemRef:$input);
let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
let results = (outs AnyMemRef);
}

def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;

// TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
def AddAxis: MgbHashableOp<"AddAxis"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$axis
);
}
def RemoveAxis: MgbHashableOp<"RemoveAxis"> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$axis
);
}

class FancyIndexingBase<string name>: MgbHashableOp<name> {
let extraArguments = (ins
MgbArrayAttr<MgbTupleAttr<
[MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>>:$items
);
}

def Subtensor: FancyIndexingBase<"Subtensor">;
def SetSubtensor: FancyIndexingBase<"SetSubtensor">;
def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">;
def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">;
def IndexingSetMultiAxisVec: FancyIndexingBase<"IndexingSetMultiAxisVec">;
def IndexingIncrMultiAxisVec: FancyIndexingBase<"IndexingIncrMultiAxisVec">;
def MeshIndexing: FancyIndexingBase<"MeshIndexing">;
def IncrMeshIndexing: FancyIndexingBase<"IncrMeshIndexing">;
def SetMeshIndexing: FancyIndexingBase<"SetMeshIndexing">;
def BatchedMeshIndexing: FancyIndexingBase<"BatchedMeshIndexing">;
def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;

def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
}

#endif // MGB_OPS

+ 1
- 0
third_party/prepare.sh View File

@@ -47,3 +47,4 @@ pushd MegRay/third_party >/dev/null
popd >/dev/null

git submodule update --init pybind11
git submodule update --init llvm-project

Loading…
Cancel
Save