You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_param_defs.py 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import textwrap
  6. import os
  7. import hashlib
  8. import struct
  9. class member_defs:
  10. """contain classes to define members of an opr param"""
  11. Dtype = collections.namedtuple('Dtype', ['cname', 'pycvt', 'pyfmt',
  12. 'cppjson', 'cname_attr'])
  13. Dtype.__new__.__defaults__ = ('', )
  14. uint32 = Dtype('uint32_t', 'int', 'I', 'NumberInt')
  15. uint64 = Dtype('uint64_t', 'int', 'Q', 'NumberInt',
  16. 'alignas(sizeof(uint64_t)) ')
  17. int32 = Dtype('int32_t', 'int', 'i', 'NumberInt')
  18. float32 = Dtype('float', 'float', 'f', 'Number')
  19. float64 = Dtype('double', 'float', 'd', 'Number')
  20. dtype = Dtype('DTypeEnum', '_as_dtype_num', 'I', 'Number')
  21. bool = Dtype('bool', 'bool', '?', 'Bool')
  22. class Base:
  23. pass
  24. class Doc:
  25. """wrap an identifier to associate document
  26. note: if the doc starts with a linebreak, it would not be reforamtted.
  27. """
  28. __slots__ = ['id', 'doc']
  29. def __init__(self, id_, doc):
  30. assert isinstance(id_, str) and isinstance(doc, str), (id_, doc)
  31. self.id = id_
  32. self.doc = doc
  33. @property
  34. def no_reformat(self):
  35. """whether reformat is disallowed for this doc string"""
  36. return self.doc.startswith('\n')
  37. @property
  38. def raw_lines(self):
  39. """the doc lines when ``no_format`` is true"""
  40. ret = self.doc.split('\n')
  41. assert not ret[0]
  42. return ret[1:]
  43. @classmethod
  44. def make(cls, v):
  45. """make doc object from str or doc"""
  46. if isinstance(v, cls):
  47. return v
  48. assert isinstance(v, str)
  49. return cls(v, '')
  50. def __str__(self):
  51. return self.id
  52. def __eq__(self, rhs):
  53. if isinstance(rhs, str):
  54. return self.id == rhs
  55. return (isinstance(rhs, Doc) and
  56. (self.id, self.doc) == (rhs.id, rhs.doc))
  57. class Enum(Base):
  58. """define an enum; the result would contain both an enum class def and its
  59. corresponding data field
  60. :param default: index of default member value
  61. :attr name_field: name of the data field of this enum in the param
  62. struct
  63. :attr member_alias: list of (member, alias) pairs
  64. """
  65. __slots__ = ['name', 'name_field', 'members', 'default',
  66. 'member_alias', 'combined']
  67. all_enums = {}
  68. """(param_name, name) => enum"""
  69. def __init__(self, param_name, name, name_field, members, default,
  70. member_alias, combined = False):
  71. name = member_defs.Doc.make(name)
  72. assert name.id[0].isupper()
  73. members = tuple(map(member_defs.Doc.make, members))
  74. if isinstance(default, str):
  75. if default not in name_field:
  76. raise ValueError(
  77. "Default value '{}' does not exist.".format(default))
  78. default = name_field.index(default)
  79. assert isinstance(default, int)
  80. self.name = name
  81. self.combined = combined
  82. self.name_field = self.get_name_field(name.id, name_field)
  83. self.members = members
  84. self.default = default
  85. self.all_enums[(param_name, name.id)] = self
  86. assert isinstance(member_alias, list)
  87. self.member_alias = member_alias
  88. @classmethod
  89. def get_name_field(cls, name, name_field):
  90. if name_field is None:
  91. name_field = name[0].lower() + name[1:]
  92. assert isinstance(name_field, str)
  93. return name_field
  94. class Field(Base):
  95. """define a normal data field"""
  96. __slots__ = ['name', 'dtype', 'default']
  97. def __init__(self, name, dtype, default):
  98. assert isinstance(dtype, member_defs.Dtype)
  99. self.name = member_defs.Doc.make(name)
  100. self.dtype = dtype
  101. self.default = default
  102. class Const(Base):
  103. """define a const data field"""
  104. __slots__ = ['name', 'dtype', 'default']
  105. def __init__(self, name, dtype, default):
  106. assert isinstance(dtype, member_defs.Dtype)
  107. self.name = member_defs.Doc.make(name)
  108. self.dtype = dtype
  109. self.default = default
  110. class EnumAlias(Base):
  111. """alias of enum type from another param"""
  112. __slots__ = ['name', 'name_field', 'src_class', 'src_name', 'default']
  113. def __init__(self, name, name_field, src_class, src_name, default):
  114. self.name = name
  115. self.name_field = member_defs.Enum.get_name_field(name, name_field)
  116. self.src_class = src_class
  117. if src_name is None:
  118. src_name = name
  119. self.src_name = src_name
  120. self.default = default
  121. @property
  122. def src_enum(self):
  123. """source Enum class"""
  124. return member_defs.Enum.all_enums[(self.src_class, self.src_name)]
  125. def get_default(self):
  126. """get default index; fallback to src index if default is not
  127. set"""
  128. if self.default is None:
  129. return self.src_enum.default
  130. return self.default
  131. class ParamDef:
  132. """"""
  133. __all_tags = set()
  134. all_param_defs = []
  135. __slots__ = ['name', 'members', 'tag', 'is_legacy']
  136. def __init__(self, name, doc='', *, version=0, is_legacy=False):
  137. self.members = []
  138. self.all_param_defs.append(self)
  139. h = hashlib.sha256(name.encode('utf-8'))
  140. if version:
  141. h.update(struct.pack('<I', version))
  142. if is_legacy:
  143. name += 'V{}'.format(version)
  144. self.name = member_defs.Doc(name, doc)
  145. self.tag = int(h.hexdigest()[:8], 16)
  146. self.is_legacy = is_legacy
  147. if self.tag < 1024:
  148. self.tag += 1024
  149. assert self.tag not in self.__all_tags, (
  150. 'tag hash confliction: name={} tag={}'.format(name, self.tag))
  151. self.__all_tags.add(self.tag)
  152. def add_fields(self, dtype, *names_defaults):
  153. assert isinstance(dtype, str)
  154. dtype = getattr(member_defs, dtype)
  155. assert len(names_defaults) % 2 == 0
  156. for i, j in zip(names_defaults[::2], names_defaults[1::2]):
  157. self.members.append(member_defs.Field(i, dtype, j))
  158. return self
  159. def add_enum(self, name, *members, default=0, name_field=None,
  160. member_alias=[]):
  161. self.members.append(member_defs.Enum(
  162. self.name.id, name, name_field, members, default, member_alias))
  163. return self
  164. def add_bit_combination_enum(self, name, *members, default=0,
  165. name_field=None, member_alias=[]):
  166. self.members.append(member_defs.Enum(
  167. self.name.id, name, name_field, members, default, member_alias, True))
  168. return self
  169. def add_enum_alias(self, name, src_class, src_name=None, name_field=None,
  170. default=None):
  171. self.members.append(member_defs.EnumAlias(
  172. name, name_field, src_class, src_name, default))
  173. return self
  174. def add_const(self, dtype, *names_defaults):
  175. assert isinstance(dtype, str)
  176. dtype = getattr(member_defs, dtype)
  177. assert len(names_defaults) % 2 == 0
  178. for i, j in zip(names_defaults[::2], names_defaults[1::2]):
  179. self.members.append(member_defs.Const(i, dtype, j))
  180. return self
  181. class WriterBase:
  182. """base class for output file writer"""
  183. _fout = None
  184. _input_hash = None
  185. def __call__(self, fout):
  186. self._fout = fout
  187. def set_input_hash(self, h):
  188. self._input_hash = h
  189. return self
  190. def _get_header(self):
  191. return 'generated by {} for {}'.format(
  192. os.path.basename(__file__),
  193. self._input_hash
  194. )
  195. def _process(self, defs):
  196. dispatch = {
  197. member_defs.Enum: self._on_member_enum,
  198. member_defs.EnumAlias: self._on_member_enum_alias,
  199. member_defs.Field: self._on_member_field,
  200. member_defs.Const: self._on_const_field
  201. }
  202. for i in defs:
  203. assert isinstance(i, ParamDef)
  204. self._on_param_begin(i)
  205. for j in i.members:
  206. dispatch[type(j)](j)
  207. self._on_param_end(i)
  208. def _on_param_begin(self, p):
  209. """:type p: :class:`.ParamDef`"""
  210. def _on_param_end(self, p):
  211. """:type p: :class:`.ParamDef`"""
  212. def _on_member_enum(self, e):
  213. """:type p: :class:`.Enum`"""
  214. def _on_member_enum_alias(self, e):
  215. """:type p: :class:`.EnumAlias`"""
  216. def _on_member_field(self, f):
  217. """:type p: :class:`.Field`"""
  218. def _on_const_field(self, f):
  219. """:type p: :class:`.Const`"""
  220. class IndentWriterBase(WriterBase):
  221. _cur_indent = ''
  222. def _indent(self):
  223. self._cur_indent += ' ' * 4
  224. def _unindent(self):
  225. self._cur_indent = self._cur_indent[:-4]
  226. def _write(self, content, *fmt, indent=0):
  227. if indent < 0:
  228. self._unindent()
  229. self._fout.write(self._cur_indent)
  230. if fmt:
  231. content = content % fmt
  232. self._fout.write(content)
  233. self._fout.write('\n')
  234. if indent > 0:
  235. self._indent()
  236. class PyWriter(IndentWriterBase):
  237. FieldDef = collections.namedtuple(
  238. 'FieldDef', ['name', 'cvt', 'fmt', 'default', 'type', 'doc'])
  239. # see _on_param_end() for the use of those fields
  240. _cur_param_name = None
  241. _cur_fields = None
  242. _cur_struct_fmt = None
  243. _enum_member2num = None
  244. def __init__(self, for_imperative=False):
  245. self._imperative = for_imperative
  246. def __call__(self, fout, defs):
  247. super().__call__(fout)
  248. self._enum_member2num = []
  249. self._write('# %s', self._get_header())
  250. self._write('import struct')
  251. self._write('from . import enum36 as enum')
  252. self._write(
  253. 'class _ParamDefBase:\n'
  254. ' def serialize(self):\n'
  255. ' tag = struct.pack("I", type(self).TAG)\n'
  256. ' pdata = [getattr(self, i) for i in self.__slots__]\n'
  257. ' for idx, v in enumerate(pdata):\n'
  258. ' if isinstance(v, _EnumBase):\n'
  259. ' pdata[idx] = _enum_member2num[id(v)]\n'
  260. ' return tag + self._packer.pack(*pdata)\n'
  261. '\n'
  262. )
  263. self._write(
  264. 'class _EnumBase(enum.Enum):\n'
  265. ' @classmethod\n'
  266. ' def __normalize(cls, val):\n'
  267. ' if isinstance(val, str):\n'
  268. ' if not hasattr(cls, "__member_upper_dict__"):\n'
  269. ' cls.__member_upper_dict__ = {k.upper(): v\n'
  270. ' for k, v in cls.__members__.items()}\n'
  271. ' val = cls.__member_upper_dict__.get(val.upper(),val)\n'
  272. ' return val\n'
  273. ' @classmethod\n'
  274. ' def convert(cls, val):\n'
  275. ' val = cls.__normalize(val)\n'
  276. ' if isinstance(val, cls):\n'
  277. ' return val\n'
  278. ' return cls(val)\n'
  279. ' @classmethod\n'
  280. ' def _missing_(cls, value):\n'
  281. ' vnorm = cls.__normalize(value)\n'
  282. ' if vnorm is not value:\n'
  283. ' return cls(vnorm)\n'
  284. ' return super()._missing_(value)\n'
  285. '\n'
  286. )
  287. if not self._imperative:
  288. self._write(
  289. 'def _as_dtype_num(dtype):\n'
  290. ' import megbrain.mgb as m\n'
  291. ' return m._get_dtype_num(dtype)\n'
  292. '\n'
  293. )
  294. self._write(
  295. 'def _as_serialized_dtype(dtype):\n'
  296. ' import megbrain.mgb as m\n'
  297. ' return m._get_serialized_dtype(dtype)\n'
  298. '\n'
  299. )
  300. else:
  301. self._write(
  302. 'def _as_dtype_num(dtype):\n'
  303. ' import megengine.core._imperative_rt.utils as m\n'
  304. ' return m._get_dtype_num(dtype)\n'
  305. '\n'
  306. )
  307. self._write(
  308. 'def _as_serialized_dtype(dtype):\n'
  309. ' import megengine.core._imperative_rt.utils as m\n'
  310. ' return m._get_serialized_dtype(dtype)\n'
  311. '\n'
  312. )
  313. self._process(defs)
  314. self._write(
  315. '''
  316. class SerializedDType(_ParamDefBase):
  317. TAG = FakeSerializedDType.TAG
  318. __slots__ = ['dtype']
  319. class IdentityPacker:
  320. def pack(self, *args):
  321. assert all([isinstance(x, bytes) for x in args])
  322. return b''.join(args)
  323. _packer = IdentityPacker()
  324. def __init__(self, dtype):
  325. """
  326. :type dtype: :class:`np.dtype` compatible
  327. """
  328. self.dtype = _as_serialized_dtype(dtype)
  329. '''
  330. )
  331. self._write('_enum_member2num = {\n %s}',
  332. ',\n '.join(self._enum_member2num))
  333. def _write_doc(self, doc):
  334. assert isinstance(doc, member_defs.Doc)
  335. if not doc.doc:
  336. return
  337. if doc.no_reformat:
  338. self._write('"""')
  339. for i in doc.raw_lines:
  340. self._write(i)
  341. self._write('"""')
  342. return
  343. doc = doc.doc.replace('\n', ' ')
  344. textwidth = 80 - len(self._cur_indent)
  345. self._write('"""')
  346. for i in textwrap.wrap(doc, textwidth):
  347. self._write(i)
  348. self._write('"""')
  349. def _on_param_begin(self, p):
  350. self._cur_param_name = str(p.name)
  351. self._cur_fields = []
  352. self._cur_enum_names = []
  353. self._write('class %s(_ParamDefBase):', p.name, indent=1)
  354. self._write_doc(p.name)
  355. self._write('TAG = %d', p.tag)
  356. def _on_param_end(self, p):
  357. # gen slots and packer
  358. self._write('__slots__ = [%s]', ', '.join(
  359. map('"{.name}"'.format, self._cur_fields)))
  360. struct_fmt = ''.join(i.fmt for i in self._cur_fields)
  361. if not struct_fmt:
  362. struct_fmt = 'x'
  363. else:
  364. # add padding at end
  365. max_t = max(struct_fmt, key=struct.calcsize)
  366. struct_fmt += '0{}'.format(max_t)
  367. self._write('_packer = struct.Struct("%s")', struct_fmt)
  368. # gen __init__ signature
  369. self._write('def __init__(%s):',
  370. ', '.join(['self'] +
  371. list('{}={}'.format(i.name, i.default)
  372. for i in self._cur_fields)),
  373. indent=1)
  374. # gen __init__ doc
  375. self._write('"""')
  376. for i in self._cur_fields:
  377. self._write(':type {}: :class:`.{}`'.format(i.name, i.type))
  378. if i.doc:
  379. self._write(':param {}: {}'.format(i.name, i.doc))
  380. self._write('"""')
  381. # gen cvt in __init__
  382. for i in self._cur_fields:
  383. self._write('self.%s = %s', i.name, i.cvt)
  384. self._unindent()
  385. self._unindent()
  386. self._write('')
  387. def _on_member_enum(self, e):
  388. qualname = '{}.{}'.format(self._cur_param_name, e.name)
  389. self._write('class %s(_EnumBase):', e.name, indent=1)
  390. self._write_doc(e.name)
  391. for idx, emem in enumerate(e.members):
  392. self._write('%s = "%s"', emem, emem)
  393. self._write_doc(emem)
  394. if e.combined:
  395. self._enum_member2num.append('id({}.{}):{}'.format(
  396. qualname, emem, 1<<idx))
  397. else:
  398. self._enum_member2num.append('id({}.{}):{}'.format(
  399. qualname, emem, idx))
  400. for emem, emem_alis in e.member_alias:
  401. self._write('%s = %s', emem_alis, emem)
  402. self._unindent()
  403. self._write('')
  404. self._cur_fields.append(self.FieldDef(
  405. name=e.name_field,
  406. cvt='{}.convert({})'.format(qualname, e.name_field),
  407. fmt='I',
  408. default="'{}'".format(e.members[e.default]),
  409. type=qualname,
  410. doc=None))
  411. def _on_member_enum_alias(self, e):
  412. self._write('%s = %s.%s', e.name, e.src_class, e.src_name)
  413. s = e.src_enum
  414. qualname = '{}.{}'.format(e.src_class, e.src_name)
  415. self._cur_fields.append(self.FieldDef(
  416. name=e.name_field,
  417. cvt='{}.convert({})'.format(qualname, e.name_field),
  418. fmt='I',
  419. default="'{}'".format(s.members[e.get_default()]),
  420. type=qualname,
  421. doc=None))
  422. def _get_py_default(self, cppdefault):
  423. if not isinstance(cppdefault, str):
  424. return cppdefault
  425. d = cppdefault
  426. if d.endswith('f'): # 1.f
  427. return d[:-1]
  428. if d.endswith('ull'):
  429. return d[:-3]
  430. if d == 'false':
  431. return 'False'
  432. if d == 'true':
  433. return 'True'
  434. if d.startswith('DTypeEnum::'):
  435. return '"{}"'.format(d.split(':')[2].lower())
  436. return d
  437. def _on_member_field(self, f):
  438. d = self._get_py_default(f.default)
  439. self._cur_fields.append(self.FieldDef(
  440. name=f.name,
  441. cvt='{}({})'.format(f.dtype.pycvt, f.name),
  442. fmt=f.dtype.pyfmt,
  443. default=d,
  444. type=f.dtype.pycvt,
  445. doc=f.name.doc
  446. ))
  447. def _on_const_field(self, f):
  448. d = self._get_py_default(f.default)
  449. self._write_doc(f.name)
  450. self._write('%s = %s', f.name, d)
  451. class CPPWriter(IndentWriterBase):
  452. _param_namespace = 'param'
  453. _ctor_args = None
  454. """list of (text in func param, var name); func param name must be var name
  455. appended by an underscore"""
  456. _non_static_members = None
  457. def __call__(self, fout, defs):
  458. super().__call__(fout)
  459. self._write('// %s', self._get_header())
  460. self._write('#pragma once')
  461. self._write('#include "megdnn/dtype.h"')
  462. self._write('#include <stdint.h>')
  463. if self._param_namespace == 'param':
  464. self._write('#include <string.h>')
  465. self._write('namespace megdnn {')
  466. self._write('namespace %s {', self._param_namespace)
  467. self._process(defs)
  468. self._write('} // namespace megdnn')
  469. self._write('} // namespace %s', self._param_namespace)
  470. self._write('// vim: syntax=cpp.doxygen')
  471. def _write_doc(self, doc):
  472. assert isinstance(doc, member_defs.Doc)
  473. if not doc.doc:
  474. return
  475. if doc.no_reformat:
  476. self._write('/*')
  477. for i in doc.raw_lines:
  478. self._write('* ' + i)
  479. self._write('*/')
  480. return
  481. doc = doc.doc.replace('\n', ' ')
  482. textwidth = 80 - len(self._cur_indent) - 4
  483. if len(doc) <= textwidth:
  484. self._write('//! ' + doc)
  485. return
  486. self._write('/*!')
  487. for i in textwrap.wrap(doc, textwidth):
  488. self._write(' * ' + i)
  489. self._write(' */')
  490. def _on_param_begin(self, p):
  491. self._write_doc(p.name)
  492. self._write('struct %s {', p.name, indent=1)
  493. self._write('static MEGDNN_CONSTEXPR uint32_t TAG = %du;', p.tag)
  494. self._ctor_args = []
  495. self._non_static_members = []
  496. def _add_ctor_args(self, typename, default, varname):
  497. self._ctor_args.append((
  498. '{} {}_={}'.format(typename, varname, default),
  499. varname))
  500. def _on_param_end(self, p):
  501. '''
  502. MegDNN param structures are not packed and we need to initialize the structure
  503. paddings to zero or it would break MegBrain hash system. We do memset(0) in default
  504. ctor and use a trick, wrapping non-static members in a anonymous union which would
  505. copy the object representation in its default copy/move ctor, for copy/move ctor.
  506. > The implicitly-defined copy/move constructor for a non-union class X performs
  507. > a memberwise copy/move of its bases and members. [class.copy.ctor 14]
  508. > The implicitly-defined copy/move constructor for a union X copies the object
  509. > representation (6.9) of X. [class.copy.ctor 15]
  510. '''
  511. if self._non_static_members:
  512. self._write('union { struct {')
  513. for i in self._non_static_members:
  514. if isinstance(i, member_defs.Field):
  515. self._write_doc(i.name)
  516. self._write('%s%s %s;', i.dtype.cname_attr, i.dtype.cname, i.name)
  517. else:
  518. assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias))
  519. self._write('%s %s;', i.name, i.name_field)
  520. self._write('}; };')
  521. if self._ctor_args:
  522. pdefs, varnames = zip(*self._ctor_args)
  523. self._write('%s(%s) {', p.name, ', '.join(pdefs), indent=1)
  524. self._write('memset(this, 0, sizeof(*this));')
  525. for var in varnames:
  526. self._write('this->%s = %s_;', var, var)
  527. self._write('}', indent=-1)
  528. self._write('};\n', indent=-1)
  529. def _on_member_enum(self, e):
  530. self._write_doc(e.name)
  531. self._write('enum class %s: uint32_t {', e.name, indent=1)
  532. for idx, i in enumerate(e.members):
  533. self._write_doc(i)
  534. v = '{} = {}'.format(i, idx)
  535. if e.combined:
  536. v = '{} = 1 << {}'.format(i, idx)
  537. if i is not e.members[-1] or e.member_alias:
  538. v += ','
  539. self._write(v)
  540. for mem, alias in e.member_alias:
  541. self._write('%s = %s,', alias, mem)
  542. self._write('};', indent=-1)
  543. self._non_static_members.append(e)
  544. self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
  545. str(e.name).upper(), len(e.members))
  546. self._add_ctor_args(e.name,
  547. '{}::{}'.format(e.name, e.members[e.default]),
  548. e.name_field)
  549. def _on_member_enum_alias(self, e):
  550. s = e.src_enum
  551. self._write('using %s = %s::%s;', e.name, e.src_class, e.src_name)
  552. self._non_static_members.append(e)
  553. self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
  554. str(e.name).upper(), len(s.members))
  555. self._add_ctor_args(e.name,
  556. '{}::{}'.format(e.name,
  557. s.members[e.get_default()]),
  558. e.name_field)
  559. def _on_member_field(self, f):
  560. self._non_static_members.append(f)
  561. self._add_ctor_args(f.dtype.cname, f.default, f.name)
  562. def _on_const_field(self, f):
  563. self._write_doc(f.name)
  564. if 'int' in f.dtype.cname:
  565. self._write('static constexpr %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default)
  566. else:
  567. self._write('static const %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default)
  568. class CPPEnumValueWriter(CPPWriter):
  569. _param_namespace = 'param_enumv'
  570. def _on_member_enum(self, e):
  571. self._write_doc(e.name)
  572. self._write('struct %s {', e.name, indent=1)
  573. for idx, val in enumerate(e.members):
  574. self._write_doc(val)
  575. self._write('static const uint32_t %s = %d;', val, idx)
  576. for mem, alias in e.member_alias:
  577. self._write('static const uint32_t %s = %s;', alias, mem)
  578. self._write('};', indent=-1)
  579. def _on_member_enum_alias(self, e):
  580. s = e.src_enum
  581. self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name)
  582. def _on_member_field(self, f):
  583. pass
  584. def _on_const_field(self, f):
  585. pass
  586. class CPPEnumItemWriter(WriterBase):
  587. _class_name = None
  588. _enum_name = None
  589. _enable = False
  590. def __init__(self, enum_def):
  591. self._class_name, self._enum_name = enum_def.split(':')
  592. def __call__(self, fout, defs):
  593. super().__call__(fout)
  594. self._process(defs)
  595. def _on_param_begin(self, p):
  596. self._enable = p.name == self._class_name
  597. def _on_member_enum(self, e):
  598. if self._enable and e.name == self._enum_name:
  599. for i in e.members:
  600. self._fout.write('{}\n'.format(i))
  601. class CPPParamJsonFuncWriter(IndentWriterBase):
  602. _param_namespace = 'param'
  603. _param_name = None
  604. _items = None
  605. def _write_json_item(self, json_cls, field):
  606. cls2ctype = {
  607. 'NumberInt': 'int64_t',
  608. 'Number': 'double',
  609. 'Bool': 'bool',
  610. }
  611. self._items.append('{"%s", json::%s::make(static_cast<%s>(p.%s))},' % (
  612. field, json_cls, cls2ctype[json_cls], field))
  613. def __call__(self, fout, defs):
  614. super().__call__(fout)
  615. self._write('// %s', self._get_header())
  616. self._write('// this file can only be included in '
  617. 'megbrain/src/plugin/impl/opr_footprint.cpp\n'
  618. '// please do not include it directly')
  619. self._write('#include "megdnn/opr_param_defs.h"')
  620. self._write('#pragma once')
  621. self._write('using namespace megdnn;')
  622. self._write('namespace mgb {')
  623. self._write('namespace opr {')
  624. self._write('template<class OprParam>')
  625. self._write('std::shared_ptr<mgb::json::Value> opr_param_to_json(const OprParam &param);')
  626. self._process(defs)
  627. self._write('} // namespace opr')
  628. self._write('} // namespace mgb')
  629. self._write('\n// vim: syntax=cpp.doxygen')
  630. def _on_param_begin(self, p):
  631. self._write('template<>', indent=0)
  632. self._write(
  633. 'std::shared_ptr<mgb::json::Value> opr_param_to_json(const param::%s &p) {',
  634. p.name, indent=1)
  635. self._param_name = 'param::{}'.format(p.name)
  636. self._items = []
  637. def _on_param_end(self, p):
  638. self._write('return json::Object::make({', indent=1)
  639. for i in self._items:
  640. self._write(i, indent=0)
  641. self._write('});', indent=-1)
  642. self._write('}', indent=-1)
  643. def _on_member_enum(self, e):
  644. self._write('auto %s2str = [](const %s::%s arg) -> std::string {',
  645. e.name, self._param_name, e.name, indent=1)
  646. self._write('switch (arg) {', indent=1)
  647. enum2str = []
  648. if isinstance(e, member_defs.EnumAlias):
  649. members = e.src_enum.members
  650. else:
  651. members = e.members
  652. for idx, i in enumerate(members):
  653. self._write('case %s::%s::%s: return "%s";',
  654. self._param_name, e.name, i, i, indent=0)
  655. self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));',
  656. self._param_name, e.name, indent=0)
  657. self._write('}', indent=-1)
  658. self._write('};', indent=-1)
  659. self._items.append('{"%s", json::String::make(%s2str(p.%s))},' % (
  660. e.name_field, e.name, e.name_field))
  661. def _on_member_enum_alias(self, e):
  662. self._on_member_enum(e)
  663. def _on_member_field(self, f):
  664. self._write_json_item(f.dtype.cppjson, f.name)
  665. def _on_const_field(self, f):
  666. pass
  667. def main():
  668. parser = argparse.ArgumentParser(
  669. 'generate opr param defs from description file')
  670. parser.add_argument('--enumv', action='store_true',
  671. help='generate c++03 compatible code which only '
  672. 'contains enum values')
  673. parser.add_argument('-t', '--type', choices=['c++', 'py'], default='c++',
  674. help='output type')
  675. parser.add_argument('--write-enum-items',
  676. help='write enum item names to output file; argument '
  677. 'should be given in the CLASS:ENUM format')
  678. parser.add_argument('--write-cppjson',
  679. help='generate megbrain json serialization implemention'
  680. 'cpp file')
  681. parser.add_argument('input')
  682. parser.add_argument('output')
  683. parser.add_argument('--imperative', action='store_true',
  684. help='generate files for imperatvie ')
  685. args = parser.parse_args()
  686. for_imperative = args.imperative
  687. with open(args.input) as fin:
  688. inputs = fin.read()
  689. exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
  690. input_hash = hashlib.sha256()
  691. input_hash.update(inputs.encode(encoding='UTF-8'))
  692. input_hash = input_hash.hexdigest()
  693. if args.type == 'py':
  694. writer = PyWriter(for_imperative=for_imperative)
  695. else:
  696. assert args.type == 'c++'
  697. if args.enumv:
  698. writer = CPPEnumValueWriter()
  699. elif args.write_enum_items:
  700. writer = CPPEnumItemWriter(args.write_enum_items)
  701. else:
  702. writer = CPPWriter()
  703. with open(args.output, 'w') as fout:
  704. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  705. if args.write_cppjson:
  706. writer = CPPParamJsonFuncWriter()
  707. with open(args.write_cppjson, 'w') as fout:
  708. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  709. if __name__ == '__main__':
  710. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。