You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_tablegen.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import textwrap
  6. import os
  7. import hashlib
  8. import struct
  9. import io
  10. from gen_param_defs import member_defs, ParamDef, IndentWriterBase
  11. # FIXME: move supportToString flag definition into the param def source file
  12. ENUM_TO_STRING_SPECIAL_RULES = [
  13. ("Elemwise", "Mode"),
  14. ("ElemwiseMultiType", "Mode")
  15. ]
  16. class ConverterWriter(IndentWriterBase):
  17. _skip_current_param = False
  18. _last_param = None
  19. _current_tparams = None
  20. _packed = None
  21. _const = None
  22. def __call__(self, fout, defs):
  23. super().__call__(fout)
  24. self._write("// %s", self._get_header())
  25. self._write("#ifndef MGB_PARAM")
  26. self._write("#define MGB_PARAM")
  27. self._process(defs)
  28. self._write("#endif // MGB_PARAM")
  29. def _ctype2attr(self, ctype, value):
  30. if ctype == 'uint32_t':
  31. return 'MgbUI32Attr', value
  32. if ctype == 'uint64_t':
  33. return 'MgbUI64Attr', value
  34. if ctype == 'int32_t':
  35. return 'MgbI32Attr', value
  36. if ctype == 'float':
  37. return 'MgbF32Attr', value
  38. if ctype == 'double':
  39. return 'MgbF64Attr', value
  40. if ctype == 'bool':
  41. return 'MgbBoolAttr', value
  42. if ctype == 'DTypeEnum':
  43. self._packed = False
  44. return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
  45. raise RuntimeError("unknown ctype")
  46. def _on_param_begin(self, p):
  47. self._last_param = p
  48. self._packed = True
  49. self._current_tparams = []
  50. self._const = set()
  51. def _on_param_end(self, p):
  52. if self._skip_current_param:
  53. self._skip_current_param = False
  54. return
  55. if self._packed:
  56. self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
  57. else:
  58. self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
  59. self._write("let fields = (ins", indent=1)
  60. self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
  61. self._write(");", indent=-1)
  62. self._write("}\n", indent=-1)
  63. if self._packed:
  64. self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
  65. self._current_tparams = None
  66. self._packed = None
  67. self._const = None
  68. def _wrapped_with_default_value(self, attr, default):
  69. return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)
  70. def _on_member_enum(self, e):
  71. p = self._last_param
  72. # Note: always generate llvm Record def for enum attribute even it was not
  73. # directly used by any operator, or other enum couldn't alias to this enum
  74. td_class = "{}{}".format(p.name, e.name)
  75. fullname = "::megdnn::param::{}".format(p.name)
  76. enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
  77. def format(v):
  78. return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0])
  79. enum_def += ','.join(format(i) for i in e.members)
  80. if e.combined:
  81. enum_def += "], 1"
  82. else:
  83. enum_def += "], 0"
  84. if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
  85. enum_def += ", 1" # whether generate ToStringTrait
  86. enum_def += ">"
  87. self._write("def {} : {};".format(td_class, enum_def))
  88. if self._skip_current_param:
  89. return
  90. # wrapped with default value
  91. if e.combined:
  92. default_val = "static_cast<{}::{}>({})".format(
  93. fullname, e.name, e.compose_combined_enum(e.default))
  94. else:
  95. default_val = "{}::{}::{}".format(
  96. fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0])
  97. wrapped = self._wrapped_with_default_value(td_class, default_val)
  98. self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
  99. def _on_member_enum_alias(self, e):
  100. p = self._last_param
  101. if self._skip_current_param:
  102. return
  103. # write enum attr def
  104. td_class = "{}{}".format(p.name, e.name)
  105. fullname = "::megdnn::param::{}".format(p.name)
  106. base_td_class = "{}{}".format(e.src_class, e.src_name)
  107. enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
  108. self._write("def {} : {};".format(td_class, enum_def))
  109. # wrapped with default value
  110. s = e.src_enum
  111. if s.combined:
  112. default_val = "static_cast<{}::{}>({})".format(
  113. fullname, e.name, s.compose_combined_enum(e.get_default()))
  114. else:
  115. default_val = "{}::{}::{}".format(fullname, e.name, str(
  116. s.members[e.get_default()]).split(' ')[0].split('=')[0])
  117. wrapped = self._wrapped_with_default_value(td_class, default_val)
  118. self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
  119. def _on_member_field(self, f):
  120. if self._skip_current_param:
  121. return
  122. attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
  123. if str(value) in self._const:
  124. value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
  125. wrapped = self._wrapped_with_default_value(attr, value)
  126. self._current_tparams.append("{}:${}".format(wrapped, f.name))
  127. def _on_const_field(self, f):
  128. self._const.add(str(f.name))
  129. def main():
  130. parser = argparse.ArgumentParser('generate op param tablegen file')
  131. parser.add_argument('input')
  132. parser.add_argument('output')
  133. args = parser.parse_args()
  134. with open(args.input) as fin:
  135. inputs = fin.read()
  136. exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
  137. input_hash = hashlib.sha256()
  138. input_hash.update(inputs.encode(encoding='UTF-8'))
  139. input_hash = input_hash.hexdigest()
  140. writer = ConverterWriter()
  141. with open(args.output, 'w') as fout:
  142. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  143. if __name__ == "__main__":
  144. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。