You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

expr.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import inspect
  12. import re
  13. from typing import Callable, Dict, List
  14. from ..core._imperative_rt import OpDef
  15. from ..core._imperative_rt.core2 import Tensor as RawTensor
  16. from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  17. from ..core.ops.builtin import FakeQuant
  18. from ..core.ops.special import Const
  19. from ..module import Module
  20. from ..tensor import Parameter, Tensor
  21. from .module_tracer import active_module_tracer, module_tracer
  22. from .node import ModuleNode, Node, NodeMixin, TensorNode
  23. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  24. from .serialization import get_opdef_state, load_opdef_from_state
  25. def rstrip(s: str, __chars: str):
  26. __chars = re.escape(__chars)
  27. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  28. return s
  29. class Expr:
  30. """``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``."""
  31. __total_id = 0
  32. inputs = None # type: List[Node]
  33. outputs = None # type: List[Node]
  34. const_val = None # type: List[Any]
  35. arg_def = None # type: TreeDef
  36. out_def = None # type: TreeDef
  37. _top_graph = None # type: weakref.ReferenceType
  38. def __init__(self) -> None:
  39. self._id = Expr.__total_id
  40. Expr.__total_id += 1
  41. self._disable_remove = False
  42. def enable_remove(self):
  43. self._disable_remove = False
  44. def disable_remove(self):
  45. self._disable_remove = True
  46. def add_inputs(self, vals):
  47. if not isinstance(vals, collections.abc.Sequence):
  48. vals = (vals,)
  49. for val in vals:
  50. node = NodeMixin.get(val, None)
  51. if isinstance(node, (TensorNode, ModuleNode)):
  52. self.inputs.append(node)
  53. node.users.append(self)
  54. else:
  55. assert node is None
  56. assert _is_leaf(val) and _is_const_leaf(val)
  57. idx = len(self.inputs) + len(self.const_val)
  58. self.const_val.append((idx, val))
  59. def add_outputs(self, outputs):
  60. self.outputs = []
  61. if outputs is not None:
  62. if not isinstance(outputs, collections.Sequence):
  63. outputs = (outputs,)
  64. name = None
  65. orig_name = None
  66. if isinstance(self, CallMethod):
  67. name = self.inputs[0]._name
  68. orig_name = self.inputs[0]._orig_name
  69. assert isinstance(name, str), "The name of ({}) must be a str".format(
  70. self.inputs[0]
  71. )
  72. assert isinstance(
  73. orig_name, str
  74. ), "The orig_name of ({}) must be a str".format(self.inputs[0])
  75. name = rstrip(name, "_out")
  76. if self.method == "__call__":
  77. name += "_out"
  78. orig_name += "_out"
  79. else:
  80. strip_method = self.method.strip("_")
  81. name = "%s_out" % strip_method
  82. orig_name = name
  83. elif isinstance(self, CallFunction):
  84. name = self.func.__name__ + "_out"
  85. elif isinstance(self, Apply):
  86. name = str(self.opdef).lower() + "_out"
  87. for i in outputs:
  88. assert isinstance(i, RawTensor), "The output must be a Tensor"
  89. o_name = (
  90. active_module_tracer().current_scope()._create_unique_name(name)
  91. )
  92. self.outputs.append(
  93. NodeMixin.get_wrapped_type(i)(
  94. expr=self,
  95. name=o_name,
  96. orig_name=orig_name if orig_name else o_name,
  97. )
  98. )
  99. for i, node in zip(outputs, self.outputs,):
  100. NodeMixin.wrap_safe(i, node)
  101. def unflatten_args(self, inputs):
  102. if self.arg_def is not None:
  103. inputs = list(inputs)
  104. for idx, val in self.const_val:
  105. inputs.insert(idx, val)
  106. args, kwargs = self.arg_def.unflatten(inputs)
  107. return args, kwargs
  108. else:
  109. return inputs, {}
  110. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  111. while repl_dict:
  112. node, repl_node = repl_dict.popitem()
  113. assert type(node) == type(repl_node)
  114. assert node in self.inputs, "({}) is not in the ({})".format(node, self)
  115. assert (
  116. repl_node.top_graph == node.top_graph
  117. ), "({}) and ({}) are not in the same graph".format(node, repl_node)
  118. graph = self.top_graph
  119. repl_expr_idx = graph._exprs.index(repl_node.expr)
  120. self_idx = graph._exprs.index(self)
  121. assert (
  122. repl_expr_idx < self_idx
  123. ), "({}) must be generated before ({})".format(repl_node, self)
  124. idx = self.inputs.index(node)
  125. self.inputs[idx] = repl_node
  126. user_idx = node.users.index(self)
  127. assert user_idx >= 0
  128. node.users.pop(user_idx)
  129. repl_node.users.append(self)
  130. @property
  131. def kwargs(self):
  132. _, kwargs = self.unflatten_args(self.inputs)
  133. return kwargs
  134. @property
  135. def args(self):
  136. args, _ = self.unflatten_args(self.inputs)
  137. return args
  138. @property
  139. def top_graph(self):
  140. if self._top_graph:
  141. return self._top_graph()
  142. return None
  143. def __getstate__(self):
  144. state = self.__dict__.copy()
  145. if "_top_graph" in state:
  146. state.pop("_top_graph")
  147. return state
  148. @classmethod
  149. def get_total_id(cls):
  150. return cls.__total_id
  151. @classmethod
  152. def set_total_id(cls, id: int = 0):
  153. assert isinstance(id, int)
  154. cls.__total_id = id
  155. # expr: None (i.e. fake expression which is used to mark input)
  156. class Input(Expr):
  157. name = None
  158. def __init__(self, name=None, type=None, orig_name=None):
  159. super().__init__()
  160. self.inputs = []
  161. node_cls = type if type else Node
  162. if orig_name is None:
  163. orig_name = name
  164. self.outputs = [
  165. node_cls(self, name=name, orig_name=orig_name),
  166. ]
  167. self.name = name
  168. @classmethod
  169. def make(cls, *args, **kwargs):
  170. expr = cls(*args, **kwargs)
  171. oup_node = expr.outputs[0]
  172. name = (
  173. active_module_tracer().current_scope()._create_unique_name(oup_node._name)
  174. )
  175. oup_node._name = name
  176. active_module_tracer().current_scope()._add_input(oup_node)
  177. return expr.outputs[0]
  178. def __repr__(self):
  179. return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)
  180. # expr: outputs = getattr(inputs[0], self.name)
  181. class GetAttr(Expr):
  182. name = None
  183. def __init__(self, module, name, type=None, orig_name=None):
  184. super().__init__()
  185. assert isinstance(module, ModuleNode)
  186. self.inputs = [
  187. module,
  188. ]
  189. module.users.append(self)
  190. self.name = name
  191. node_cls = type if type else Node
  192. self.outputs = [
  193. node_cls(self, name=name, orig_name=orig_name),
  194. ]
  195. @classmethod
  196. def make(cls, *args, **kwargs):
  197. expr = cls(*args, **kwargs)
  198. module = expr.inputs[0]
  199. oup_name = expr.name
  200. while module._name != "self":
  201. oup_name = module._name + "_" + oup_name
  202. module = module.expr.inputs[0]
  203. oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
  204. expr.outputs[0]._name = oup_name
  205. active_module_tracer().current_scope()._insert(expr)
  206. return expr.outputs[0]
  207. def interpret(self, *inputs):
  208. return (getattr(inputs[0], self.name),)
  209. def __repr__(self):
  210. out_type = "Tensor"
  211. if isinstance(self.outputs[0], ModuleNode):
  212. out_type = self.outputs[0].module_type.__name__
  213. return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
  214. self._id, self.outputs[0], self.inputs[0], self.name, out_type
  215. )
  216. # expr: outputs = inputs[0].__call__(*inputs[1:])
  217. class CallMethod(Expr):
  218. def __init__(self, node, method="__call__"):
  219. super().__init__()
  220. if isinstance(node, type):
  221. assert issubclass(node, Tensor)
  222. cls = Parameter if issubclass(node, Parameter) else Tensor
  223. self.inputs = []
  224. self.const_val = [(0, cls)]
  225. else:
  226. assert isinstance(node, (TensorNode, ModuleNode))
  227. node.users.append(self)
  228. self.inputs = [
  229. node,
  230. ]
  231. self.const_val = []
  232. self.method = method
  233. @classmethod
  234. def make(cls, *args, **kwargs):
  235. expr = cls(*args, **kwargs)
  236. active_module_tracer().current_scope()._insert(expr)
  237. return expr
  238. @property
  239. def graph(self):
  240. if isinstance(self.inputs[0], ModuleNode):
  241. m_node = self.inputs[0]
  242. if (
  243. hasattr(m_node.owner, "argdef_graph_map")
  244. and m_node.owner.argdef_graph_map
  245. ):
  246. assert self.arg_def in m_node.owner.argdef_graph_map
  247. return m_node.owner.argdef_graph_map[self.arg_def]
  248. return None
  249. def interpret(self, *inputs):
  250. args, kwargs = self.unflatten_args(inputs)
  251. obj = args[0]
  252. meth = getattr(obj, self.method)
  253. if inspect.ismethod(meth):
  254. args = args[1:]
  255. outputs = getattr(obj, self.method)(*args, **kwargs)
  256. if self.method == "__setitem__":
  257. outputs = obj
  258. if outputs is None:
  259. return outputs
  260. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  261. return outputs
  262. def __repr__(self):
  263. args = ", ".join(str(i) for i in self.args[1:])
  264. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  265. outputs = self.outputs
  266. if self.out_def:
  267. outputs = self.out_def.unflatten(outputs)
  268. method = ".%s" % self.method
  269. if method == ".__call__":
  270. method = ""
  271. return "%{}:\t{}{}{}({})".format(
  272. self._id,
  273. str(outputs) + " = " if outputs else "",
  274. self.args[0],
  275. method,
  276. ", ".join([args, kwargs]),
  277. )
  278. # expr: outputs = apply(self.opdef, *inputs)
  279. class Apply(Expr):
  280. opdef = None
  281. def __init__(self, opdef):
  282. super().__init__()
  283. assert isinstance(opdef, OpDef)
  284. self.opdef = opdef
  285. self.inputs = []
  286. @classmethod
  287. def make(cls, *args, **kwargs):
  288. expr = cls(*args, **kwargs)
  289. active_module_tracer().current_scope()._insert(expr)
  290. return expr
  291. def interpret(self, *inputs):
  292. return apply(self.opdef, *inputs)
  293. def __repr__(self):
  294. return "%{}:\t{} = {}({})".format(
  295. self._id,
  296. ", ".join(str(i) for i in self.outputs),
  297. self.opdef,
  298. ", ".join(str(i) for i in self.inputs),
  299. )
  300. def __getstate__(self):
  301. state = super().__getstate__()
  302. state["opdef"] = get_opdef_state(state["opdef"])
  303. return state
  304. def __setstate__(self, state):
  305. state["opdef"] = load_opdef_from_state(state["opdef"])
  306. for k, v in state.items():
  307. setattr(self, k, v)
  308. @classmethod
  309. def apply_module_trace_hook(cls, opdef, *inputs):
  310. for i in inputs:
  311. node = NodeMixin.get(i, None)
  312. if node is None: # capture as constant
  313. NodeMixin.wrap_safe(i, Constant.make(i))
  314. if isinstance(opdef, FakeQuant):
  315. inp_nodes = [NodeMixin.get(inputs[0])]
  316. for i in inputs[1:]:
  317. node = Constant.make(i)
  318. inp_nodes.append(node)
  319. apply_node = cls.make(opdef)
  320. for n in inp_nodes:
  321. n.users.append(apply_node)
  322. apply_node.inputs = inp_nodes
  323. else:
  324. apply_node = cls.make(opdef)
  325. apply_node.add_inputs(inputs)
  326. assert not apply_node.const_val
  327. unset_module_tracing()
  328. outputs = apply(opdef, *inputs)
  329. set_module_tracing()
  330. apply_node.add_outputs(outputs)
  331. for n, v in zip(apply_node.outputs, outputs):
  332. NodeMixin.wrap_safe(v, n)
  333. return list(outputs)
  334. class CallFunction(Expr):
  335. def __init__(self, func):
  336. super().__init__()
  337. assert isinstance(func, Callable)
  338. self.func = func
  339. self.const_val = []
  340. self.inputs = []
  341. @classmethod
  342. def make(cls, *args, **kwargs):
  343. expr = cls(*args, **kwargs)
  344. active_module_tracer().current_scope()._insert(expr)
  345. return expr
  346. def interpret(self, *inputs):
  347. args, kwargs = self.unflatten_args(inputs)
  348. outputs = self.func(*args, **kwargs)
  349. if outputs is None:
  350. return outputs
  351. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  352. return outputs
  353. def __repr__(self):
  354. args = ", ".join(str(i) for i in self.args)
  355. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  356. outputs = self.outputs
  357. if self.out_def:
  358. outputs = self.out_def.unflatten(outputs)
  359. return "%{}:\t{}{}({})".format(
  360. self._id,
  361. str(outputs) + " = " if outputs else "",
  362. self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
  363. ", ".join([args, kwargs]),
  364. )
  365. # expr outputs = self.value
  366. class Constant(Expr):
  367. value = None
  368. # TODO: constant cache to reduce the size of dumped model
  369. _constant_cache = {}
  370. def __init__(self, c, name=None):
  371. super().__init__()
  372. assert isinstance(c, (RawTensor, Module))
  373. if isinstance(c, Module):
  374. assert module_tracer.is_builtin(c) or c.is_qat
  375. self.value = c
  376. self.name = name
  377. self.inputs = []
  378. node_cls = NodeMixin.get_wrapped_type(c)
  379. self.outputs = [
  380. node_cls(self, name=name, orig_name=name),
  381. ]
  382. self.outputs[0]._name = name if name else "const_" + str(self._id)
  383. @classmethod
  384. def make(cls, *args, **kwargs):
  385. expr = cls(*args, **kwargs)
  386. name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
  387. full_name = name
  388. if (
  389. isinstance(expr.value, RawTensor)
  390. and id(expr.value) in active_module_tracer().id2name
  391. ):
  392. full_name = active_module_tracer().id2name[id(expr.value)]
  393. scope_name = active_module_tracer().current_scope()._module_name
  394. if full_name and scope_name:
  395. full_name = ("self." + full_name)[len(scope_name) + 1 :]
  396. else:
  397. full_name = name
  398. else:
  399. full_name = name
  400. name = active_module_tracer().current_scope()._create_unique_name(full_name)
  401. expr.outputs[0]._name = name
  402. expr.outputs[0]._orig_name = full_name
  403. active_module_tracer().current_scope()._insert(expr)
  404. return expr.outputs[0]
  405. def interpret(self, *inputs):
  406. if isinstance(self.value, RawTensor):
  407. return Const(self.value.numpy())()
  408. return (self.value,)
  409. def __repr__(self):
  410. name = self.name
  411. if name is None:
  412. name = type(self.value)
  413. node_type = "Module"
  414. if isinstance(self.outputs[0], TensorNode):
  415. node_type = "Tensor"
  416. return "%{}:\t{} = Constant({}) -> ({})".format(
  417. self._id, self.outputs[0], name, node_type
  418. )
  419. def __getstate__(self):
  420. state = self.__dict__.copy()
  421. if "_top_graph" in state:
  422. state.pop("_top_graph")
  423. if isinstance(self.value, RawTensor):
  424. state["value"] = Tensor(self.value)
  425. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台