You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

craniotome.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. """used for creating a megbrain operator from python"""
  10. import copy
  11. import itertools
  12. from abc import ABCMeta, abstractmethod, abstractproperty
  13. from . import helper as _helper
  14. from . import mgb as _mgb
  15. class _CraniotomeBaseMeta(ABCMeta):
  16. _base_created = False
  17. def __init__(cls, name, bases, member_dict):
  18. if _CraniotomeBaseMeta._base_created:
  19. assert "__init__" not in member_dict, (
  20. "Craniotome operators should not overwrite __init__ method; "
  21. "use setup() instead."
  22. )
  23. forbidden = set(
  24. k for k in dir(CraniotomeBase) if k[0] == "_" and k[1] != "_"
  25. )
  26. forbidden.add("get_io_vars")
  27. check_key = member_dict.get("__check_key__", True)
  28. whitelist = ["__classcell__"]
  29. for k in member_dict.keys():
  30. assert k not in forbidden, "{} could not be overwritten".format(k)
  31. if (
  32. check_key
  33. and k.startswith("__")
  34. and k.endswith("__")
  35. and k not in whitelist
  36. and not hasattr(CraniotomeBase, k)
  37. ):
  38. raise KeyError(
  39. "name {} in class {} does not exist in the baseclass".format(
  40. k, name
  41. )
  42. )
  43. else:
  44. _CraniotomeBaseMeta._base_created = True
  45. super().__init__(name, bases, member_dict)
  46. class CraniotomeBase(_mgb.CraniotomeDesc, metaclass=_CraniotomeBaseMeta):
  47. """base class used for extending megbrain core operators in python
  48. Note: all names starting and ending with two underscores in the subclasses
  49. would be checked and KeyError would be raised if the name does not exist in
  50. the base class. This behavor can be disabled by setting ``__check_key__``
  51. to ``False`` (see the testcase for more details)
  52. """
  53. # methods and attributes to be overwritten by subclasses
  54. __expand_single_outputs__ = True
  55. """if :attr:`__nr_outputs__` is 1, whether to return a single
  56. :class:`.SymbolVar` instead of a tuple in :meth:`make`"""
  57. __is_dynamic_output_shape__ = False
  58. """whether output shape could not be inferred from input shape. If value of
  59. this attribute is ``False``, :meth:`infer_shape` must be implemented. If
  60. this attribute is ``True`` but the operator has no inputs, then
  61. :meth:`infer_shape` would also be called to infer output shape before
  62. operator execution.
  63. """
  64. __disable_sys_mem_alloc__ = False
  65. """whether to disable system memory allocator. This is used when
  66. :attr:`__is_dynamic_output_shape__` is ``False`` but the output memory
  67. should not be managed by megbrain system (so it can be forwarded from
  68. external buffer)"""
  69. __allow_duplicate__ = True
  70. """whether this operator can be duplicated (e.g. used in sublinear
  71. memory)"""
  72. __allow_empty_out__ = False
  73. """whether empty output shape is allowed; if it is set as ``False``, then
  74. an exception would be raised if output var is empty to prevent erroneously
  75. forgetting initializing output vars"""
  76. @abstractproperty
  77. def __nr_inputs__(self):
  78. """number of input vars"""
  79. @abstractproperty
  80. def __nr_outputs__(self):
  81. """number of output vars"""
  82. @abstractmethod
  83. def execute(self, inputs, outputs):
  84. """execute the operator, read values from *inputs* by calling
  85. :meth:`.CompGraphCallbackValueProxy.get_value` and write results into
  86. *outputs* by calling :meth:`.SharedND.set_value`
  87. :param inputs: values for each input var
  88. :type inputs: tuple of :class:`.CompGraphCallbackValueProxy`
  89. :param outputs: values for each output var
  90. :type outputs: tuple of :class:`.SharedND`
  91. """
  92. def setup(self):
  93. """overwritten by subclass to accept kwargs passed to :meth:`make` to
  94. setup the operator"""
  95. def infer_shape(self, inp_shapes):
  96. """infer output shape from input shapes
  97. :type inp_shapes: tuple of tuple of ints
  98. :param inp_shapes: input shapes for each input var
  99. :rtype: tuple of tuple of ints
  100. :return: output shapes for each output var
  101. """
  102. raise NotImplementedError(
  103. "{}: infer_shape() not implemented; for operators with dynamic "
  104. "output shape, __is_dynamic_output_shape__ should be set to True".format(
  105. self
  106. )
  107. )
  108. def grad(self, wrt_idx, inputs, outputs, out_grad):
  109. """compute symbolic gradient; should be overwritten by differentiable
  110. subclasses
  111. :type wrt_idx: int
  112. :param wrt_idx: the input var with respect to which the gradient should
  113. be computed; please also see the notes below
  114. :type inputs: tuple of :class:`.SymbolVar`
  115. :param inputs: input symbol vars
  116. :type outputs: tuple of :class:`.SymbolVar`
  117. :param outputs: output symbol vars
  118. :type out_grad: tuple of (:class:`.SymbolVar` or None)
  119. :param out_grad: gradients of loss with respect to each output var
  120. .. note::
  121. In case when loss does not depend on some var (i.e. zero grad),
  122. the corresponding value in *out_grad* would be ``None``. It is
  123. guaranteed that at least one element in *out_grad* is not
  124. ``None``.
  125. .. note::
  126. This function can return either of the following:
  127. 1. Gradient of the input specified by ``wrt_idx``
  128. 2. A list containing gradients of all inputs. In this case,
  129. ``wrt_idx`` can be ignored.
  130. And the so called gradient can be either one of:
  131. 1. A :class:`.SymbolVar` representing the symbolic gradient
  132. value
  133. 2. ``0`` representing zero gradient
  134. """
  135. raise NotImplementedError("grad for {} not implemented".format(self))
  136. def init_output_dtype(self, input_dtypes):
  137. """infer output dtypes from input dtypes; return None to use default
  138. infer function in megbrain.
  139. .. note::
  140. This method must be implemented if there is no input var
  141. :param input_dtypes: input dtypes
  142. :type input_dtypes: list of :class:`numpy.dtype`
  143. :rtype: None or list of :class:`numpy.dtype`-compatible
  144. """
  145. def get_serialize_params(self):
  146. """get params for megbrain graph serialization. This function should
  147. return a list or tuple, containing one or two elements: the first
  148. element must be a string, representing the name passed to
  149. ``opr_loader_maker`` during deserializing; the second element, if
  150. exists, must be convertible to ``bytes`` and is used for dumping any
  151. extra opr params, which can be retrieved by ``load_buf_with_len``
  152. during deserializing.
  153. """
  154. raise NotImplementedError(
  155. "get_serialize_params() for {} not implemented".format(self)
  156. )
  157. def copy(self):
  158. """copy this craniotome descriptor; the default implementation creates
  159. a new object, and copies object ``__dict__``"""
  160. ret = type(self)()
  161. d0 = self.__dict__.copy()
  162. d0.pop("this")
  163. ret.__dict__.update(copy.deepcopy(d0))
  164. return ret
  165. def on_graph_compiled(self, used_outputs):
  166. """a callback that would be invoked when the graph is compiled; it
  167. would always have a matching :meth:`on_compiled_func_deleted` call
  168. :param used_outputs: indices of outputs that are needed for the
  169. computation
  170. :type used_outputs: ``tuple of int``
  171. """
  172. def on_compiled_func_deleted(self):
  173. """a callback that would be invoked when the compiled function is
  174. destructed; it would always have a matching :meth:`on_graph_compiled`
  175. call"""
  176. def get_io_vars(self):
  177. """get input vars, comp order dep vars and output vars
  178. :return: a dict with keys ``'input'``, ``'output'`` and
  179. ``'comp_order'`` that maps to corresponding list of vars
  180. """
  181. all_vars = list(self._get_all_io_vars())
  182. nr_inp = self.__nr_inputs__
  183. nr_out = self.__nr_outputs__
  184. nr_comp_order = self._get_nr_dev_comp_order_deps()
  185. s0 = nr_inp + nr_comp_order
  186. return dict(
  187. input=all_vars[:nr_inp],
  188. comp_order=all_vars[nr_inp:s0],
  189. output=all_vars[s0:],
  190. )
  191. @property
  192. def owner_opr_id(self):
  193. """ID of the operator that owns this descriptor"""
  194. return self._get_opr_id()
  195. @property
  196. def comp_node(self):
  197. """comp node on which this operator runs"""
  198. return self._get_comp_node()
  199. # below are methods that should not be changed
  200. def _hash(self):
  201. return int(hash(self)) % (1 << 64)
  202. def _setup_self(self, dst):
  203. dst.append(self)
  204. def _is_same(self, rhs):
  205. return bool(self == rhs)
  206. def _node_flag(self):
  207. return (
  208. (int(bool(self.__is_dynamic_output_shape__)) << 0)
  209. | (int(not self.__allow_duplicate__) << 1)
  210. | (int(bool(self.__allow_empty_out__)) << 2)
  211. | (int(bool(self.__disable_sys_mem_alloc__)) << 3)
  212. )
  213. def _get_opr_type_name(self):
  214. return str(self.__class__.__name__)
  215. def _get_nr_outputs(self):
  216. return int(self.__nr_outputs__)
  217. def _execute(self, inputs, outputs):
  218. inputs = tuple(inputs)
  219. outputs = tuple(outputs)
  220. if not self.__is_dynamic_output_shape__:
  221. out_shapes = [i.shape for i in outputs]
  222. self.execute(inputs, outputs)
  223. if not self.__is_dynamic_output_shape__:
  224. new_shapes = [i.shape for i in outputs]
  225. assert (
  226. out_shapes == new_shapes
  227. ), "output shape changed after executing {}: before={} after={}".format(
  228. self, out_shapes, new_shapes
  229. )
  230. def _infer_shape(self, inp_shapes):
  231. inp_shapes = tuple(tuple(map(int, i)) for i in inp_shapes)
  232. oshp_get = self.infer_shape(inp_shapes)
  233. assert (
  234. len(oshp_get) == self.__nr_outputs__
  235. ), "{}: expect {} outputs; got {}(val: {}) from infer_shape".format(
  236. self, self.__nr_outputs__, len(oshp_get), oshp_get
  237. )
  238. return _helper.cvt_to_vector_of_shape(oshp_get)
  239. def _grad(self, wrt_idx, inputs, outputs, out_grad):
  240. og = []
  241. for i in out_grad:
  242. if i.valid:
  243. og.append(i)
  244. else:
  245. og.append(None)
  246. rst = self.grad(int(wrt_idx), tuple(inputs), tuple(outputs), tuple(og))
  247. if not isinstance(rst, (list, tuple)):
  248. rst = [rst]
  249. else:
  250. assert len(rst) == len(
  251. inputs
  252. ), "{}: opr has {} inputs but {} grads are returned".format(
  253. self, len(inputs), len(rst)
  254. )
  255. for i in range(len(rst)):
  256. cur = rst[i]
  257. if cur is 0:
  258. rst[i] = _mgb.SymbolVar()
  259. else:
  260. assert isinstance(cur, _mgb.SymbolVar), (
  261. "{}: invalid grad result; it should be either "
  262. "0 or a SymbolVar, got {!r} instead".format(self, cur)
  263. )
  264. return rst
  265. def _get_nr_dev_comp_order_deps(self):
  266. return 0
  267. def _init_output_dtype(self, input_dtypes, ret):
  268. get = self.init_output_dtype(input_dtypes)
  269. if get is not None:
  270. assert isinstance(ret, (list, tuple)) and len(get) == len(ret)
  271. ret[:] = get
  272. return True
  273. assert self.__nr_inputs__, (
  274. "{}: init_output_dtype must be implemented "
  275. "if there is no input var".format(self)
  276. )
  277. return False
  278. def _setup_serialize_params(self, output):
  279. val = list(self.get_serialize_params())
  280. assert len(val) in [1, 2]
  281. name = val[0]
  282. assert isinstance(name, str)
  283. output.append(name)
  284. if len(val) == 2:
  285. output.append(bytes(val[1]))
  286. def _copy(self):
  287. ret = self.copy()
  288. assert type(ret) is type(
  289. self
  290. ), "copy() returned different type: src={} copied={}".format(
  291. type(self), type(ret)
  292. )
  293. assert ret is not self
  294. ret.__disown__()
  295. self._set_copy_result(ret)
  296. def _on_graph_compile_or_func_del(self, used_outputs):
  297. if used_outputs:
  298. self.on_graph_compiled(used_outputs)
  299. else:
  300. self.on_compiled_func_deleted()
  301. def __repr__(self):
  302. return "cranoiotome:{}".format(self.__class__.__name__)
  303. @classmethod
  304. def make(
  305. cls,
  306. *inputs,
  307. comp_graph=None,
  308. name=None,
  309. comp_node=None,
  310. config=None,
  311. dev_comp_order_deps=[],
  312. **kwargs
  313. ):
  314. """apply this operator on some input vars and return corresponding
  315. output vars
  316. :type inputs: tuple of :class:`.SymbolVar`
  317. :param inputs: input symvars; immediate values could also be accepted,
  318. as long as there is symvar to infer comp node and comp graph
  319. :param comp_graph: if there is no input vars, *comp_graph* must be
  320. provided to specify which computing graph to insert this operator
  321. :param dev_comp_order_deps: vars that must have been computed
  322. before executing this operator
  323. :param kwargs: extra keyword arguments to be passed to :meth:`setup` of
  324. this class
  325. :param name: name of the resulting operator
  326. :rtype: tuple of :class:`.SymbolVar`
  327. :return: output symvars
  328. """
  329. if not inputs and not dev_comp_order_deps:
  330. assert isinstance(
  331. comp_graph, _mgb.CompGraph
  332. ), "{}: comp_graph must be given if no inputs provided".format(self)
  333. desc = cls()
  334. desc.setup(**kwargs)
  335. assert (
  336. len(inputs) == desc.__nr_inputs__
  337. ), "{}: expected {} inputs, got {}".format(
  338. desc, desc.__nr_inputs__, len(inputs)
  339. )
  340. config = _helper.gen_config(name, comp_node, config)
  341. # get inp_vec
  342. inp_vec = _mgb._VectorSymbolVar()
  343. for i in _helper.canonize_input_vars(
  344. itertools.chain(inputs, dev_comp_order_deps),
  345. comp_graph=comp_graph,
  346. config=config,
  347. ):
  348. inp_vec.push_back(i)
  349. desc._get_nr_dev_comp_order_deps = lambda *, val=len(dev_comp_order_deps): val
  350. if comp_graph is not None:
  351. desc._get_comp_graph = lambda: comp_graph
  352. expand_single_outputs = desc.__expand_single_outputs__
  353. desc.__disown__()
  354. rst = _mgb.make_opr_from_craniotome_desc(desc, inp_vec, config)
  355. if expand_single_outputs and len(rst) == 1:
  356. return rst[0]
  357. return tuple(rst)
  358. def make_opr(cls):
  359. """decorator used to wrap a :class:`.CraniotomeBase` subclass and return
  360. its :meth:`~.CraniotomeBase.make` method
  361. """
  362. assert issubclass(cls, CraniotomeBase)
  363. return cls.make

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台