You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_flatbuffers_converter.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import textwrap
  6. import os
  7. import hashlib
  8. import struct
  9. import io
  10. from gen_param_defs import member_defs, ParamDef, IndentWriterBase
  11. class ConverterWriter(IndentWriterBase):
  12. _skip_current_param = False
  13. _last_param = None
  14. _param_fields = None
  15. _fb_fields = []
  16. def __call__(self, fout, defs):
  17. super().__call__(fout)
  18. self._write("// %s", self._get_header())
  19. self._write('#include <flatbuffers/flatbuffers.h>')
  20. self._write("namespace mgb {")
  21. self._write("namespace serialization {")
  22. self._write("namespace fbs {")
  23. self._process(defs)
  24. self._write("} // namespace fbs")
  25. self._write("} // namespace serialization")
  26. self._write("} // namespace mgb")
  27. def _on_param_begin(self, p):
  28. self._last_param = p
  29. self._param_fields = []
  30. self._fb_fields = ["builder"]
  31. self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {",
  32. p.name, indent=1)
  33. self._write("using MegDNNType = megdnn::param::%s;", p.name)
  34. self._write("using FlatBufferType = fbs::param::%s;\n", p.name)
  35. def _on_param_end(self, p):
  36. if self._skip_current_param:
  37. self._skip_current_param = False
  38. return
  39. self._write("static MegDNNType to_param(const FlatBufferType* fb) {",
  40. indent=1)
  41. line = 'return {'
  42. line += ', '.join(self._param_fields)
  43. line += '};'
  44. self._write(line)
  45. self._write("}\n", indent=-1)
  46. self._write(
  47. "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {",
  48. indent=1)
  49. line = 'return fbs::param::Create{}('.format(str(p.name))
  50. line += ', '.join(self._fb_fields)
  51. line += ');'
  52. self._write(line)
  53. self._write('}', indent=-1)
  54. self._write("};\n", indent=-1)
  55. def _on_member_enum(self, e):
  56. p = self._last_param
  57. key = str(p.name) + str(e.name)
  58. if self._skip_current_param:
  59. return
  60. self._param_fields.append(
  61. "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
  62. str(p.name), str(e.name), e.name_field))
  63. self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format(
  64. key, e.name_field))
  65. def _on_member_field(self, f):
  66. if self._skip_current_param:
  67. return
  68. if f.dtype.cname == 'DTypeEnum':
  69. self._param_fields.append(
  70. "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name))
  71. self._fb_fields.append(
  72. "intl::convert_dtype_to_fbs(param.{})".format(f.name))
  73. else:
  74. self._param_fields.append("fb->{}()".format(f.name))
  75. self._fb_fields.append("param.{}".format(f.name))
  76. def _on_const_field(self, f):
  77. pass
  78. def _on_member_enum_alias(self, e):
  79. if self._skip_current_param:
  80. return
  81. enum_name = e.src_class + e.src_name
  82. self._param_fields.append(
  83. "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
  84. e.src_class, e.src_name, e.name_field))
  85. self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format(
  86. enum_name, e.name_field))
  87. def main():
  88. parser = argparse.ArgumentParser(
  89. 'generate convert functions between FlatBuffers type and MegBrain type')
  90. parser.add_argument('input')
  91. parser.add_argument('output')
  92. args = parser.parse_args()
  93. with open(args.input) as fin:
  94. inputs = fin.read()
  95. exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
  96. input_hash = hashlib.sha256()
  97. input_hash.update(inputs.encode(encoding='UTF-8'))
  98. input_hash = input_hash.hexdigest()
  99. writer = ConverterWriter()
  100. with open(args.output, 'w') as fout:
  101. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  102. if __name__ == "__main__":
  103. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。