You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

helper.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections.abc

import numpy as np

from . import mgb
from .exc import MegBrainError
from .mgb import SharedND, SymbolVar
from .opr_param_defs import OptionalAxisV1
  15. def canonize_reshape(inputs, *, comp_graph, config):
  16. src, tshape = inputs
  17. tshape = cvt_to_shape_desc(tshape, src, comp_graph, config)
  18. return src, tshape
  19. def canonize_shape_input(inputs, *, comp_graph, config):
  20. assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
  21. return [cvt_to_shape_desc(inputs[0], None, comp_graph, config)]
  22. def cvt_to_shape_desc(val, inpvar, graph, config):
  23. """convert some python object to a :class:`SymbolVar` that describes tensor
  24. shape
  25. :param val: the python object to be converted from
  26. :param inpvar, graph, config: provide graph and comp node information; can
  27. be None if not known. Either input or (graph, config) must be provided.
  28. :return: a new var corresponding to *val*
  29. :rtype: :class:`.SymbolVar`
  30. """
  31. if hasattr(val, "__mgb_symvar__"):
  32. val = val.__mgb_symvar__()
  33. elif hasattr(val, "symvar"):
  34. val = val.symvar
  35. if isinstance(val, SymbolVar):
  36. return val
  37. if not isinstance(val, collections.Iterable):
  38. val = [val]
  39. components = []
  40. has_sym = False
  41. for i in val:
  42. if hasattr(i, "__mgb_symvar__"):
  43. i = i.__mgb_symvar__()
  44. elif hasattr(i, "symvar"):
  45. i = i.symvar
  46. if isinstance(i, SymbolVar):
  47. has_sym = True
  48. components.append(i)
  49. else:
  50. assert isinstance(i, int), (
  51. "shape desc could contain either int or SymbolVar, got {}"
  52. " actually".format(repr(i))
  53. )
  54. components.append(i)
  55. assert components, "shape desc could not be empty"
  56. if inpvar is not None:
  57. assert isinstance(inpvar, SymbolVar)
  58. if graph is None:
  59. graph = inpvar.owner_graph
  60. else:
  61. assert graph == inpvar.owner_graph
  62. config = mgb.make_opr_config(comp_node=inpvar.comp_node)
  63. else:
  64. assert isinstance(graph, mgb.CompGraph), "graph must be provided"
  65. assert isinstance(config, mgb.OperatorNodeConfig)
  66. if not has_sym:
  67. shape = np.ascontiguousarray(components, dtype=np.int32)
  68. assert np.all(shape == components), "failed to convert to shape: {}".format(
  69. components
  70. )
  71. return mgb._make_immutable(graph, shape, None, config)
  72. for idx, v in enumerate(components):
  73. if not isinstance(v, SymbolVar):
  74. vi = int(v)
  75. assert vi == v, "could not convert {} to int".format(v)
  76. components[idx] = mgb._make_immutable(graph, vi, None, config)
  77. from . import opr as O
  78. return O.concat(components, axis=0, config=config)
def canonize_input_vars(inputs, *, comp_graph, config):
    """convert immediate numbers and SharedND to SymbolVar in inputs; at least
    one of the inputs must be SymbolVar, so comp node and comp graph can
    be inferred

    :return: list of converted vars
    """
    from . import make_immutable

    if (
        isinstance(inputs, (list, tuple))
        and len(inputs) == 1
        and isinstance(inputs[0], (list, tuple))
    ):
        # handle the case when a list is passed to a function with
        # variable-length argument (e.g. concat has signature concat(*inputs)
        # and is called with concat([a, b]))
        inputs = inputs[0]

    if isinstance(inputs, SymbolVar):
        return [inputs]

    # first pass: find a SymbolVar to supply comp node / comp graph, and
    # check whether any element needs conversion at all
    old_inputs = inputs
    inputs = []
    get_comp_node = None  # lazily-resolved provider of the target comp node
    need_cvt = False
    for i in old_inputs:
        if isinstance(i, SymbolVar):
            # bind this var's comp node as the default (cn=... captures the
            # value now, avoiding the late-binding closure pitfall)
            get_comp_node = lambda cn=i.comp_node: cn
            if comp_graph is not None:
                assert comp_graph == i.owner_graph
            else:
                comp_graph = i.owner_graph
        else:
            need_cvt = True
        inputs.append(i)
    if not need_cvt:
        return inputs

    if get_comp_node is None:
        # no SymbolVar among inputs: fall back to the comp node required by
        # *config*, resolved on first call and then cached by rebinding the
        # enclosing name to a constant lambda
        def get_comp_node():
            nonlocal get_comp_node
            cn = config.require_comp_node()
            get_comp_node = lambda: cn
            return cn

    # second pass: convert every non-SymbolVar element in place
    for idx, var in enumerate(inputs):
        if not isinstance(var, SymbolVar):
            if isinstance(var, SharedND):
                var = var.symvar(comp_graph)
            elif isinstance(var, mgb.SharedScalar):
                var = var._as_sym_var(comp_graph, get_comp_node())
            elif hasattr(var, "__mgb_symvar__"):
                # comp node may legitimately be unresolvable here; pass None
                try:
                    cn = get_comp_node()
                except MegBrainError:
                    cn = None
                var = var.__mgb_symvar__(comp_graph=comp_graph, comp_node=cn)
            elif hasattr(var, "symvar"):
                var = var.symvar
            else:
                # immediate number / array: wrap as an immutable tensor
                var = make_immutable(get_comp_node(), comp_graph, var)
            inputs[idx] = var
    return inputs
  137. def cvt_to_vector_of_shape(shapes):
  138. """convert ``[[int]]`` to nested ``std::vector`` of ``size_t``"""
  139. ret = mgb._VectorTensorShape()
  140. for i in shapes:
  141. val = tuple(i)
  142. assert val and all(
  143. j > 0 and isinstance(j, int) for j in val
  144. ), "something returns bad shape in infer_shape(): {}".format(val)
  145. ret.push_back(val)
  146. return ret
  147. def cvt_to_opr_param_def(param, ptype, kwargs):
  148. if param is not None:
  149. if isinstance(param, ptype):
  150. return param
  151. param = [param]
  152. assert len(param) == len(
  153. ptype.__slots__
  154. ), "{} needs {} params, but {} are provided".format(
  155. ptype, len(ptype.__slots__), len(param)
  156. )
  157. return ptype(*param)
  158. ckw = {}
  159. for i in ptype.__slots__:
  160. val = kwargs.pop(i, ckw)
  161. if val is not ckw:
  162. ckw[i] = val
  163. return ptype(**ckw)
def cvt_getitem_to_idx_desc(inpvar, tuple_val, *, allow_newaxis=True):
    """convert ``__getitem__`` args to index desc

    :return: ``(new_var, index_desc)`` where new_var is inpvar with
        ``np.newaxis`` applied; note that ``index_desc`` can be ``None``.
    """
    assert isinstance(inpvar, SymbolVar), "bad input: {!r}".format(inpvar)
    if not isinstance(tuple_val, tuple):
        tuple_val = (tuple_val,)

    axis_indexer = mgb._VectorAxisIndexer()
    config = mgb.make_opr_config(comp_node=inpvar.comp_node)
    graph = inpvar.owner_graph

    def as_symvar(v, *, allow_list=True):
        # convert an index item to a SymbolVar; non-vars become int32
        # immutables, and must be losslessly representable as int32
        if isinstance(v, SymbolVar):
            return v
        vi = np.ascontiguousarray(v, dtype=np.int32)
        assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v)
        return mgb._make_immutable(graph, vi, None, config)

    def _s(v):  # convert slice item
        if v is None:
            # an invalid (default-constructed) SymbolVar marks an omitted
            # slice bound
            return SymbolVar()
        return as_symvar(v, allow_list=False)

    new_axes = []
    cur_axis = -1
    for i_idx, i in enumerate(tuple_val):
        cur_axis += 1
        if i is np.newaxis:
            if cur_axis >= 0:
                new_axes.append(cur_axis)
            continue

        if i is Ellipsis:
            # after an ellipsis, axes are counted from the end: walk the
            # remaining items in reverse, assigning negative axis numbers
            cur_axis = -1
            for j in tuple_val[:i_idx:-1]:
                if j is Ellipsis:
                    raise IndexError("only one ellipsis is allowed")
                if j is np.newaxis:
                    new_axes.append(cur_axis)
                cur_axis -= 1
            continue

        if isinstance(i, slice):
            if i.start is None and i.stop is None and i.step is None:
                # full slice ``:`` selects the whole axis; nothing to record
                continue
            cur = mgb._AxisIndexer.make_interval(
                cur_axis, _s(i.start), _s(i.stop), _s(i.step)
            )
        else:
            cur = mgb._AxisIndexer.make_index(cur_axis, as_symvar(i))
        axis_indexer.push_back(cur)

    if new_axes:
        if not allow_newaxis:
            raise IndexError("newaxis is not allowed here")
        inpvar = mgb._Opr.add_axis(inpvar, new_axes, mgb.make_opr_config())
    if axis_indexer.empty():
        axis_indexer = None
    return inpvar, axis_indexer
  218. def cvt_to_reshape_unspec_axis(unspec_axis, tshape):
  219. assert isinstance(unspec_axis, OptionalAxisV1), repr(unspec_axis)
  220. unspec_axis = unspec_axis.axis
  221. assert abs(unspec_axis) <= OptionalAxisV1.MAX_NDIM
  222. if not isinstance(tshape, SymbolVar):
  223. for idx, val in enumerate(tshape):
  224. if val == -1:
  225. assert (
  226. unspec_axis == OptionalAxisV1.INVALID_AXIS
  227. ), "multiple unknown dimensions for reshape"
  228. unspec_axis = idx
  229. return OptionalAxisV1(unspec_axis)
  230. def gen_config(name, comp_node, config, output_dtype=None):
  231. if config is None:
  232. config = mgb.make_opr_config(name, comp_node, output_dtype)
  233. else:
  234. assert isinstance(config, mgb.OperatorNodeConfig)
  235. assert name is None and comp_node is None
  236. return config
def cvt_opr_result(rst, *, explode_single=True):
    """:param explode_single: whether to return the content of a single-item
    list rather than the list itself"""
    if not isinstance(rst, mgb.SymbolVar):
        # list/tuple result: recursively convert each element
        assert isinstance(rst, (list, tuple))
        if len(rst) == 1 and explode_single:
            return cvt_opr_result(rst[0])
        return tuple(map(cvt_opr_result, rst))

    # an invalid var maps to None
    if not rst.valid:
        return None

    # TODO Because the __init__ of SwigObject can not be modified to keep the
    # reference of graph, we get owner graph explicitly here. The correct
    # handling is moving the reference to SwigWrapper, but it is unsupported to
    # add a member variable to SwigWrapper, so we should wrap the SymbolVar
    # manually in megbrain_wrap.h
    rst.owner_graph

    f32 = np.float32

    if not hasattr(cvt_opr_result, "_cvt_to_float32"):
        # first call: cache the MGB_ALL_FLOAT32 env-var lookup on the function
        # object so the lookup and the warning banner happen only once
        import os
        from .logconf import get_logger

        cvt_opr_result._cvt_to_float32 = os.getenv("MGB_ALL_FLOAT32")
        if cvt_opr_result._cvt_to_float32:
            get_logger().warn(
                "\n"
                "+=====================================================+\n"
                "| MGB_ALL_FLOAT32 is set, so all megbrain opr result |\n"
                "| would to converted to float32; this should only be |\n"
                "| used for loading old models. |\n"
                "+=====================================================+"
            )
    if cvt_opr_result._cvt_to_float32 and rst.dtype != f32:
        rst = rst.astype(f32)
    return rst

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台