You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_op_defs.py 18 kB


  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import collections
  11. import textwrap
  12. import os
  13. import hashlib
  14. import struct
  15. class member_defs:
  16. """contain classes to define members of an opr param"""
  17. Dtype = collections.namedtuple('Dtype', ['cname', 'pycvt', 'pyfmt',
  18. 'cppjson', 'cname_attr'])
  19. Dtype.__new__.__defaults__ = ('', )
  20. uint32 = Dtype('uint32_t', 'int', 'I', 'NumberInt')
  21. uint64 = Dtype('uint64_t', 'int', 'Q', 'NumberInt',
  22. 'alignas(sizeof(uint64_t)) ')
  23. int32 = Dtype('int32_t', 'int', 'i', 'NumberInt')
  24. float32 = Dtype('float', 'float', 'f', 'Number')
  25. float64 = Dtype('double', 'float', 'd', 'Number')
  26. dtype = Dtype('DTypeEnum', '_as_dtype_num', 'I', 'Number')
  27. bool = Dtype('bool', 'bool', '?', 'Bool')
  28. class Base:
  29. pass
  30. class Doc:
  31. """wrap an identifier to associate document
  32. note: if the doc starts with a linebreak, it would not be reforamtted.
  33. """
  34. __slots__ = ['id', 'doc']
  35. def __init__(self, id_, doc):
  36. assert isinstance(id_, str) and isinstance(doc, str), (id_, doc)
  37. self.id = id_
  38. self.doc = doc
  39. @property
  40. def no_reformat(self):
  41. """whether reformat is disallowed for this doc string"""
  42. return self.doc.startswith('\n')
  43. @property
  44. def raw_lines(self):
  45. """the doc lines when ``no_format`` is true"""
  46. ret = self.doc.split('\n')
  47. assert not ret[0]
  48. return ret[1:]
  49. @classmethod
  50. def make(cls, v):
  51. """make doc object from str or doc"""
  52. if isinstance(v, cls):
  53. return v
  54. assert isinstance(v, str)
  55. return cls(v, '')
  56. def __str__(self):
  57. return self.id
  58. def __eq__(self, rhs):
  59. if isinstance(rhs, str):
  60. return self.id == rhs
  61. return (isinstance(rhs, Doc) and
  62. (self.id, self.doc) == (rhs.id, rhs.doc))
  63. class Enum(Base):
  64. """define an enum; the result would contain both an enum class def and its
  65. corresponding data field
  66. :param default: index of default member value
  67. :attr name_field: name of the data field of this enum in the param
  68. struct
  69. :attr member_alias: list of (member, alias) pairs
  70. """
  71. __slots__ = ['name', 'name_field', 'members', 'default',
  72. 'member_alias']
  73. all_enums = {}
  74. """(param_name, name) => enum"""
  75. def __init__(self, param_name, name, name_field, members, default,
  76. member_alias):
  77. name = member_defs.Doc.make(name)
  78. assert name.id[0].isupper()
  79. members = tuple(map(member_defs.Doc.make, members))
  80. if isinstance(default, str):
  81. if default not in name_field:
  82. raise ValueError(
  83. "Default value '{}' does not exist.".format(default))
  84. default = name_field.index(default)
  85. assert isinstance(default, int)
  86. self.name = name
  87. self.name_field = self.get_name_field(name.id, name_field)
  88. self.members = members
  89. self.default = default
  90. self.all_enums[(param_name, name.id)] = self
  91. assert isinstance(member_alias, list)
  92. self.member_alias = member_alias
  93. @classmethod
  94. def get_name_field(cls, name, name_field):
  95. if name_field is None:
  96. name_field = name[0].lower() + name[1:]
  97. assert isinstance(name_field, str)
  98. return name_field
  99. class Field(Base):
  100. """define a normal data field"""
  101. __slots__ = ['name', 'dtype', 'default']
  102. def __init__(self, name, dtype, default):
  103. assert isinstance(dtype, member_defs.Dtype)
  104. self.name = member_defs.Doc.make(name)
  105. self.dtype = dtype
  106. self.default = default
  107. class Const(Base):
  108. """define a const data field"""
  109. __slots__ = ['name', 'dtype', 'default']
  110. def __init__(self, name, dtype, default):
  111. assert isinstance(dtype, member_defs.Dtype)
  112. self.name = member_defs.Doc.make(name)
  113. self.dtype = dtype
  114. self.default = default
  115. class EnumAlias(Base):
  116. """alias of enum type from another param"""
  117. __slots__ = ['name', 'name_field', 'src_class', 'src_name', 'default']
  118. def __init__(self, name, name_field, src_class, src_name, default):
  119. self.name = name
  120. self.name_field = member_defs.Enum.get_name_field(name, name_field)
  121. self.src_class = src_class
  122. if src_name is None:
  123. src_name = name
  124. self.src_name = src_name
  125. self.default = default
  126. @property
  127. def src_enum(self):
  128. """source Enum class"""
  129. return member_defs.Enum.all_enums[(self.src_class, self.src_name)]
  130. def get_default(self):
  131. """get default index; fallback to src index if default is not
  132. set"""
  133. if self.default is None:
  134. return self.src_enum.default
  135. return self.default
  136. class ParamDef:
  137. """"""
  138. __all_tags = set()
  139. all_param_defs = []
  140. __slots__ = ['name', 'members', 'tag', 'is_legacy']
  141. def __init__(self, name, doc='', *, version=0, is_legacy=False):
  142. self.members = []
  143. self.all_param_defs.append(self)
  144. h = hashlib.sha256(name.encode('utf-8'))
  145. if version:
  146. h.update(struct.pack('<I', version))
  147. if is_legacy:
  148. name += 'V{}'.format(version)
  149. self.name = member_defs.Doc(name, doc)
  150. self.tag = int(h.hexdigest()[:8], 16)
  151. self.is_legacy = is_legacy
  152. if self.tag < 1024:
  153. self.tag += 1024
  154. assert self.tag not in self.__all_tags, (
  155. 'tag hash confliction: name={} tag={}'.format(name, self.tag))
  156. self.__all_tags.add(self.tag)
  157. def add_fields(self, dtype, *names_defaults):
  158. assert isinstance(dtype, str)
  159. dtype = getattr(member_defs, dtype)
  160. assert len(names_defaults) % 2 == 0
  161. for i, j in zip(names_defaults[::2], names_defaults[1::2]):
  162. self.members.append(member_defs.Field(i, dtype, j))
  163. return self
  164. def add_enum(self, name, *members, default=0, name_field=None,
  165. member_alias=[]):
  166. self.members.append(member_defs.Enum(
  167. self.name.id, name, name_field, members, default, member_alias))
  168. return self
  169. def add_enum_alias(self, name, src_class, src_name=None, name_field=None,
  170. default=None):
  171. self.members.append(member_defs.EnumAlias(
  172. name, name_field, src_class, src_name, default))
  173. return self
  174. def add_const(self, dtype, *names_defaults):
  175. assert isinstance(dtype, str)
  176. dtype = getattr(member_defs, dtype)
  177. assert len(names_defaults) % 2 == 0
  178. for i, j in zip(names_defaults[::2], names_defaults[1::2]):
  179. self.members.append(member_defs.Const(i, dtype, j))
  180. return self
  181. class WriterBase:
  182. """base class for output file writer"""
  183. _fout = None
  184. _input_hash = None
  185. _cur_class = None
  186. def __call__(self, fout):
  187. self._fout = fout
  188. def set_input_hash(self, h):
  189. self._input_hash = h
  190. return self
  191. def _get_header(self):
  192. return 'generated by {} for {}'.format(
  193. os.path.basename(__file__),
  194. self._input_hash
  195. )
  196. def _process(self, defs):
  197. dispatch = {
  198. member_defs.Enum: self._on_member_enum,
  199. member_defs.EnumAlias: self._on_member_enum_alias,
  200. member_defs.Field: self._on_member_field,
  201. member_defs.Const: self._on_const_field
  202. }
  203. for i in defs:
  204. assert isinstance(i, ParamDef)
  205. if i.is_legacy:
  206. continue
  207. self._cur_class = i.name
  208. self._on_param_begin(i)
  209. for j in i.members:
  210. dispatch[type(j)](j)
  211. self._on_param_end(i)
  212. def _on_param_begin(self, p):
  213. """:type p: :class:`.ParamDef`"""
  214. def _on_param_end(self, p):
  215. """:type p: :class:`.ParamDef`"""
  216. def _on_member_enum(self, e):
  217. """:type p: :class:`.Enum`"""
  218. def _on_member_enum_alias(self, e):
  219. """:type p: :class:`.EnumAlias`"""
  220. def _on_member_field(self, f):
  221. """:type p: :class:`.Field`"""
  222. def _on_const_field(self, f):
  223. """:type p: :class:`.Const`"""
  224. class IndentWriterBase(WriterBase):
  225. _cur_indent = ''
  226. def _indent(self):
  227. self._cur_indent += ' ' * 4
  228. def _unindent(self):
  229. self._cur_indent = self._cur_indent[:-4]
  230. def _write(self, content, *fmt, indent=0):
  231. if indent < 0:
  232. self._unindent()
  233. self._fout.write(self._cur_indent)
  234. if fmt:
  235. content = content % fmt
  236. self._fout.write(content)
  237. self._fout.write('\n')
  238. if indent > 0:
  239. self._indent()
  240. class PyWriter(IndentWriterBase):
  241. _static_members = None
  242. _non_static_members = None
  243. _enums = None
  244. _enum_map = None
  245. def __call__(self, fout, defs):
  246. super().__call__(fout)
  247. self._enum_map = {}
  248. self._write('// %s', self._get_header())
  249. self._write('#include "megbrain/imperative/opdef/all.h"')
  250. self._write('')
  251. self._write('using namespace mgb::imperative;')
  252. self._write('')
  253. self._process(defs)
  254. def _on_param_begin(self, p):
  255. self._enums = []
  256. self._non_static_members = []
  257. self._static_members = []
  258. def _reg_enum_single(self, cur_def, e):
  259. alias = None
  260. if isinstance(e, member_defs.Enum):
  261. src = e
  262. else:
  263. assert isinstance(e, member_defs.EnumAlias)
  264. src = e.src_enum
  265. alias = e
  266. src_py_name = self._enum_map.get(src, None)
  267. if src_py_name is not None:
  268. py_name = '{}{}Enum'.format(cur_def, src.name if alias is None else alias.name)
  269. self._write('m.attr("{}") = m.attr("{}");\n'.format(py_name, src_py_name))
  270. return
  271. if alias is None:
  272. enum_name = str(src.name)
  273. else:
  274. enum_name = str(alias.name)
  275. c_name = 'opdef::{}::{}'.format(cur_def, enum_name)
  276. py_name = '{}{}Enum'.format(cur_def, enum_name)
  277. self._write('py::enum_<{}>(m, "{}")'.format(c_name, py_name), indent=1)
  278. for i in src.members:
  279. self._write('.value("{0}", {1}::{0})'.format(i, c_name))
  280. self._write(';\n', indent=-1)
  281. self._enum_map[src] = py_name
  282. def _on_param_end(self, p):
  283. cur_def = '{}Def'.format(p.name)
  284. for e in self._enums:
  285. self._reg_enum_single(cur_def, e)
  286. self._write('py::class_<opdef::{0}>(m, "{0}")'.format(cur_def), indent=1)
  287. # TODO: use ctor with given default value
  288. self._write('.def(py::init<>())')
  289. for i in self._static_members:
  290. assert isinstance(i, member_defs.Const)
  291. self._write('.def_property_readonly_static("{0}", []() {{ return opdef::{1}::{0}; }})'.format(i.name, cur_def))
  292. for i in self._non_static_members:
  293. fname = None
  294. if isinstance(i, member_defs.Field):
  295. fname = i.name
  296. else:
  297. assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias))
  298. fname = i.name_field
  299. self._write('.def_readwrite("{0}", &opdef::{1}::{0})'.format(fname, cur_def))
  300. self._write(';\n', indent=-1)
  301. def _on_member_enum(self, e,):
  302. self._enums.append(e)
  303. self._non_static_members.append(e)
  304. def _on_member_enum_alias(self, e):
  305. self._enums.append(e)
  306. self._non_static_members.append(e)
  307. def _on_member_field(self, f):
  308. self._non_static_members.append(f)
  309. def _on_const_field(self, f):
  310. self._static_members.append(f)
  311. class CPPWriter(IndentWriterBase):
  312. _param_namespace = 'opdef'
  313. _ctor_args = None
  314. """list of (text in func param, var name); func param name must be var name
  315. appended by an underscore"""
  316. _non_static_members = None
  317. def __call__(self, fout, defs):
  318. super().__call__(fout)
  319. self._write('// %s', self._get_header())
  320. self._write('#pragma once')
  321. self._write('#include "megdnn.h"')
  322. # which defined in megbrain/tools/param_defs/mgb_opr_param_defs.py
  323. self._write('#include "megbrain/opr/param_defs.h"')
  324. self._write('#include <stdint.h>')
  325. self._write('namespace mgb {')
  326. self._write('namespace imperative {')
  327. self._write('namespace %s {', self._param_namespace)
  328. self._write('namespace {')
  329. self._write('#include "megdnn/dtype.h"')
  330. self._write('using DTypeEnum = megdnn::DTypeEnum;')
  331. self._write('} // anonymous namespace')
  332. self._process(defs)
  333. self._write('} // namespace %s', self._param_namespace)
  334. self._write('} // namespace imperative')
  335. self._write('} // namespace mgb')
  336. self._write('// vim: syntax=cpp.doxygen')
  337. def _on_param_begin(self, p):
  338. self._write('struct %sDef {', p.name, indent=1)
  339. self._ctor_args = []
  340. self._non_static_members = []
  341. def _add_ctor_args(self, typename, default, varname):
  342. self._ctor_args.append((
  343. '{} {}_={}'.format(typename, varname, default),
  344. varname))
  345. def _on_param_end(self, p):
  346. '''
  347. MegDNN param structures are not packed and we need to initialize the structure
  348. paddings to zero or it would break MegBrain hash system. We do memset(0) in default
  349. ctor and use a trick, wrapping non-static members in a anonymous union which would
  350. copy the object representation in its default copy/move ctor, for copy/move ctor.
  351. > The implicitly-defined copy/move constructor for a non-union class X performs
  352. > a memberwise copy/move of its bases and members. [class.copy.ctor 14]
  353. > The implicitly-defined copy/move constructor for a union X copies the object
  354. > representation (6.9) of X. [class.copy.ctor 15]
  355. '''
  356. if self._non_static_members:
  357. self._write('union { struct {')
  358. for i in self._non_static_members:
  359. if isinstance(i, member_defs.Field):
  360. self._write('%s%s %s;', i.dtype.cname_attr, i.dtype.cname, i.name)
  361. else:
  362. assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias))
  363. self._write('%s %s;', i.name, i.name_field)
  364. self._write('}; };')
  365. param_list = []
  366. if self._ctor_args:
  367. pdefs, varnames = zip(*self._ctor_args)
  368. self._write('%sDef(%s) {', p.name, ', '.join(pdefs), indent=1)
  369. self._write('memset(this, 0, sizeof(*this));')
  370. for var in varnames:
  371. self._write('this->%s = %s_;', var, var)
  372. param_list.append(str(var))
  373. self._write('}', indent=-1)
  374. self._write('megdnn::param::%s param() {', self._cur_class, indent=1)
  375. self._write('return {%s};', ','.join(param_list))
  376. self._write('}', indent=-1)
  377. self._write('};\n', indent=-1)
  378. def __on_member_enum(self, e, default_value):
  379. self._write('using %s = megdnn::param::%s::%s;', e.name, self._cur_class, e.name)
  380. self._non_static_members.append(e)
  381. self._add_ctor_args(e.name, default_value, e.name_field)
  382. def _on_member_enum(self, e,):
  383. self.__on_member_enum(e, '{}::{}'.format(e.name, e.members[e.default]))
  384. def _on_member_enum_alias(self, e):
  385. self.__on_member_enum(e, '{}::{}'.format(e.name, e.src_enum.members[e.get_default()]))
  386. def _on_member_field(self, f):
  387. self._non_static_members.append(f)
  388. self._add_ctor_args(f.dtype.cname, f.default, f.name)
  389. def _on_const_field(self, f):
  390. if 'int' in f.dtype.cname:
  391. self._write('static constexpr %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default)
  392. else:
  393. self._write('static const %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default)
  394. def main():
  395. parser = argparse.ArgumentParser(
  396. 'generate opr param defs from description file')
  397. parser.add_argument('-t', '--type', choices=['c++', 'py'], default='c++',
  398. help='output type')
  399. parser.add_argument('input')
  400. parser.add_argument('output')
  401. args = parser.parse_args()
  402. with open(args.input) as fin:
  403. inputs = fin.read()
  404. exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
  405. input_hash = hashlib.sha256()
  406. input_hash.update(inputs.encode(encoding='UTF-8'))
  407. input_hash = input_hash.hexdigest()
  408. if args.type == 'py':
  409. writer = PyWriter()
  410. else:
  411. writer = CPPWriter()
  412. with open(args.output, 'w') as fout:
  413. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  414. if __name__ == '__main__':
  415. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。