You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_flatbuffers_schema.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import textwrap
  6. import os
  7. import hashlib
  8. import struct
  9. import io
  10. from gen_param_defs import member_defs, ParamDef, IndentWriterBase
  11. def _cname_to_fbname(cname):
  12. return {
  13. "uint32_t": "uint",
  14. "uint64_t": "ulong",
  15. "int32_t": "int",
  16. "float": "float",
  17. "double": "double",
  18. "DTypeEnum": "DTypeEnum",
  19. "bool": "bool",
  20. }[cname]
  def scramble_enum_member_name(name):
      """Return ``name`` adjusted for use as a FlatBuffers enum member.

      The two spellings ``MIN`` and ``MAX`` get a trailing underscore; every
      other name is returned unchanged.
      NOTE(review): presumably these clash with identifiers that flatc
      generates for enums -- confirm before relying on this rationale.
      """
      clashing = ("MIN", "MAX")
      if name not in clashing:
          return name
      return name + "_"
  class FlatBuffersWriter(IndentWriterBase):
      """Emit a FlatBuffers schema (``.fbs``) for operator param definitions.

      ``__call__`` renders all param tables into an in-memory buffer first,
      then writes the header, the enums those tables referenced, and finally
      the buffered tables to the real output.  NOTE(review): the ``_on_*``
      hooks below are presumably invoked by ``IndentWriterBase._process`` for
      each param / member -- confirm against the base class.
      """

      # True while visiting a legacy param, for which no table is written.
      _skip_current_param = False
      # The param most recently opened by ``_on_param_begin``.
      _last_param = None
      # (param name, enum name) -> enum definition, for every enum seen.
      _enums = None
      # Keys into ``_enums`` that some emitted table references.
      _used_enum = None
      # Const-field name -> value for the current param.  A fresh dict is
      # installed by ``__call__`` and ``_on_param_begin``; this class-level
      # default must never be mutated.
      _cur_const_val = {}

      def __call__(self, fout, defs):
          """Write the complete schema for ``defs`` to the text stream ``fout``.

          Tables go to a temporary buffer first because the set of referenced
          enums is only known after all params have been processed, and the
          enums must precede the tables in the output.
          """
          param_io = io.StringIO()
          super().__call__(param_io)
          # reset every piece of per-run state so an instance can be reused,
          # even after a previous run aborted half-way through a param
          self._used_enum = set()
          self._enums = {}
          self._cur_const_val = {}
          self._skip_current_param = False
          self._last_param = None
          self._process(defs)

          super().__call__(fout)
          self._write("// %s", self._get_header())
          self._write('include "dtype.fbs";')
          self._write("namespace mgb.serialization.fbs.param;\n")
          self._write_enums()
          self._write(param_io.getvalue())

      def _write_enums(self):
          """Write an ``enum ... : uint`` declaration for every used enum.

          Enums are emitted in sorted key order so the output is
          deterministic.  Members of combined enums get explicit
          ``1 << index`` values; other enums list members without values.
          """
          for key in sorted(self._used_enum):
              param_name = key[0]
              enum_def = self._enums[key]
              self._write_doc(enum_def.name)
              self._write("enum %s%s : uint {", param_name, enum_def.name,
                          indent=1)
              for idx, member in enumerate(enum_def.members):
                  self._write_doc(member)
                  member_name = scramble_enum_member_name(str(member))
                  if enum_def.combined:
                      self._write("%s=%d,", member_name, 1 << idx)
                  else:
                      self._write("%s,", member_name)
              self._write("}\n", indent=-1)

      def _write_doc(self, doc):
          """Write ``doc`` as ``///`` comment lines at the current indent.

          Only ``member_defs.Doc`` objects with non-empty text produce output;
          anything else (e.g. a plain string name) is ignored.  Docs flagged
          ``no_reformat`` keep their own lines; all others are joined into a
          single paragraph and re-wrapped to fit within 80 columns.
          """
          if not isinstance(doc, member_defs.Doc) or not doc.doc:
              return
          if doc.no_reformat:
              doc_lines = doc.raw_lines
          else:
              text = doc.doc.replace('\n', ' ')
              # 80 columns minus the current indent and the "/// " prefix
              text_width = 80 - len(self._cur_indent) - 4
              doc_lines = textwrap.wrap(text, text_width)
          for line in doc_lines:
              self._write("/// " + line)

      def _on_param_begin(self, p):
          """Start param ``p``: open its ``table`` unless it is legacy."""
          # bookkeeping is done even for legacy params: their enums are still
          # registered by ``_on_member_enum`` and consts are param-scoped
          self._last_param = p
          self._cur_const_val = {}
          # set (not just raise) the flag so a stale True can never leak in
          self._skip_current_param = bool(p.is_legacy)
          if self._skip_current_param:
              return
          self._write_doc(p.name)
          self._write("table %s {", p.name, indent=1)

      def _on_param_end(self, p):
          """Finish param ``p``: close its table, or clear the skip flag."""
          if self._skip_current_param:
              self._skip_current_param = False
              return
          self._write("}\n", indent=-1)

      def _on_member_enum(self, e):
          """Register enum ``e`` of the current param and write its field."""
          p = self._last_param
          key = str(p.name), str(e.name)
          # registered before the skip check so that ``_write_enums`` can
          # still find it if another param aliases this enum
          self._enums[key] = e
          if self._skip_current_param:
              return
          self._write_doc(e.name)
          self._used_enum.add(key)
          self._write("%s:%s%s = %s;", e.name_field, p.name, e.name,
                      scramble_enum_member_name(str(e.members[e.default])))

      def _resolve_const(self, v):
          """Follow const-field references until ``v`` is not a const name.

          :raises ValueError: if the const definitions form a cycle, which
              would otherwise make this loop forever.
          """
          seen = set()
          while v in self._cur_const_val:
              if v in seen:
                  raise ValueError(
                      "cyclic const definition involving %r" % (v,))
              seen.add(v)
              v = self._cur_const_val[v]
          return v

      def _on_member_field(self, f):
          """Write a scalar field ``f`` with its default translated to fbs."""
          if self._skip_current_param:
              return
          self._write_doc(f.name)
          self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname),
                      self._get_fb_default(self._resolve_const(f.default)))

      def _on_const_field(self, f):
          """Remember const ``f`` so later field defaults can refer to it."""
          self._cur_const_val[str(f.name)] = str(f.default)

      def _on_member_enum_alias(self, e):
          """Write a field whose type is an enum declared by another param."""
          if self._skip_current_param:
              return
          self._used_enum.add((e.src_class, e.src_name))
          enum_name = e.src_class + e.src_name
          self._write(
              "%s:%s = %s;", e.name_field, enum_name,
              scramble_enum_member_name(str(e.src_enum.members[e.get_default()])))

      def _get_fb_default(self, cppdefault):
          """Convert a C++ default-value literal into FlatBuffers syntax.

          Non-string values are returned unchanged.  For strings, the float
          suffix ``f`` (``1.f``) and the integer suffix ``ull`` are dropped,
          and a leading ``DTypeEnum::`` qualifier is removed.
          """
          if not isinstance(cppdefault, str):
              return cppdefault
          d = cppdefault
          # A trailing "f" is the float suffix ("1.f", "0.5f"), but in a hex
          # literal such as "0xff" it is a digit and must be kept.
          is_hex = d.lstrip('+-').lower().startswith('0x')
          if d.endswith('f') and not is_hex:
              return d[:-1]
          if d.endswith('ull'):
              return d[:-3]
          prefix = "DTypeEnum::"
          if d.startswith(prefix):
              return d[len(prefix):]
          return d
  def main():
      """Command-line entry point: generate the FlatBuffers param schema.

      Reads the param description file ``input``, executes it so that it
      populates ``ParamDef.all_param_defs``, and writes the schema to
      ``output``.  The SHA-256 of the input text is passed to the writer via
      ``set_input_hash``.
      """
      # The first positional argument of ArgumentParser is ``prog``; the
      # text below is a description and must be passed by keyword.
      parser = argparse.ArgumentParser(
          description='generate FlatBuffers schema of operator param from '
                      'description file')
      parser.add_argument('input', help='param description file')
      parser.add_argument('output', help='path of the .fbs schema to write')
      args = parser.parse_args()

      # Decode explicitly as UTF-8: the hash below is computed over the UTF-8
      # bytes, and a non-UTF-8 locale default would mis-decode any
      # non-ASCII doc text.
      with open(args.input, encoding='utf-8') as fin:
          inputs = fin.read()
      # NOTE(security): the description file is executed as Python code;
      # only run this tool on trusted, in-tree input files.
      exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})

      input_hash = hashlib.sha256(inputs.encode('utf-8')).hexdigest()

      writer = FlatBuffersWriter()
      # doc comments copied from the input may be non-ASCII
      with open(args.output, 'w', encoding='utf-8') as fout:
          writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)


  if __name__ == "__main__":
      main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。