You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 18 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import abc
  10. import json
  11. import sys
  12. from typing import Callable, Sequence
  13. import numpy as np
  14. from ..core import _imperative_rt as rt
  15. from ..core._imperative_rt.core2 import SymbolVar
  16. from ..core._wrap import Device
  17. from ..core.ops import builtin
  18. from ..core.tensor.array_method import ArrayMethodMixin
  19. from ..core.tensor.indexing import getitem as _getitem
  20. from ..core.tensor.indexing import setitem as _setitem
  21. from ..core.tensor.megbrain_graph import InputNode, OutputNode
  22. from ..tensor import Tensor
  23. from .comp_graph_tools import replace_vars
  24. from .module_stats import (
  25. preprocess_receptive_field,
  26. register_flops,
  27. register_receptive_field,
  28. )
  29. class NetworkNode:
  30. pass
  31. class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
  32. pass
  33. class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
  34. def __init__(self, var=None, *, owner_opr=None, name=None):
  35. SymbolVar.__init__(self, var)
  36. self.owner = owner_opr
  37. self.name = name
  38. self.id = id(self)
  39. @classmethod
  40. def load(cls, sym_var, owner_opr):
  41. obj = cls()
  42. obj.var = sym_var # mgb varnode
  43. obj.name = sym_var.name
  44. obj.owner = owner_opr
  45. return obj
  46. @property
  47. def shape(self):
  48. rst = None
  49. if self.var:
  50. try:
  51. rst = self.var.shape
  52. except:
  53. rst = None
  54. return rst
  55. @property
  56. def dtype(self):
  57. return self.var.dtype if self.var else None
  58. def __bool__(self):
  59. return False
  60. __index__ = None
  61. __int__ = None
  62. __float__ = None
  63. __complex__ = None
  64. def __hash__(self):
  65. return id(self)
  66. @property
  67. def _tuple_shape(self):
  68. return self.var.shape
  69. def numpy(self):
  70. o = OutputNode(self.var)
  71. self.graph.compile(o.outputs).execute()
  72. return o.get_value().numpy()
  73. def __getitem__(self, index):
  74. return _getitem(self, index)
  75. def __setitem__(self, index, value):
  76. if index is not Ellipsis:
  77. value = _setitem(self, index, value)
  78. if self.owner is not None:
  79. idx = self.owner.outputs.index(self)
  80. self.owner.outputs[idx] = VarNode(
  81. self.var, owner_opr=self.owner, name=self.var.name
  82. )
  83. self.var = value.var
  84. self.owner = None
  85. def set_owner_opr(self, owner_opr):
  86. self.owner = owner_opr
  87. class OpNode(NetworkNode):
  88. opdef = None
  89. type = None
  90. def __init__(self):
  91. self.inputs = []
  92. self.outputs = []
  93. self.params = {}
  94. self._opr = None # mgb opnode
  95. self.id = id(self)
  96. @classmethod
  97. def load(cls, opr):
  98. obj = cls()
  99. obj.params = json.loads(opr.params)
  100. obj.name = opr.name
  101. obj._opr = opr
  102. return obj
  103. def compile(self, graph=None):
  104. op = self.opdef(**self.params)
  105. args = [i.var for i in self.inputs]
  106. outputs = rt.invoke_op(op, args)
  107. assert len(outputs) == len(self.outputs)
  108. self._opr = outputs[0].owner
  109. for i in range(len(self.outputs)):
  110. self.outputs[i].var = outputs[i]
  111. self.outputs[i].var.name = self.outputs[i].name
  112. assert self.outputs[i].owner is self
  113. def add_inp_var(self, x):
  114. self.inputs.append(x)
  115. def add_out_var(self, x):
  116. self.outputs.append(x)
  117. def str_to_mge_class(classname):
  118. # TODO: use megbrain C++ RTTI to replace type string
  119. if classname == "RNGOpr<MegDNNOpr>":
  120. classname = "RNGOpr"
  121. oprcls = getattr(sys.modules[__name__], classname, None)
  122. return oprcls if oprcls else ReadOnlyOpNode
  123. class Host2DeviceCopy(OpNode):
  124. type = "Host2DeviceCopy"
  125. def __init__(self, shape=None, dtype=None, name=None, device=None):
  126. super().__init__()
  127. self.shape = shape
  128. self.dtype = dtype
  129. self.name = name
  130. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  131. self.outputs = []
  132. @classmethod
  133. def load(cls, opr):
  134. self = cls()
  135. self.outputs = []
  136. assert len(opr.outputs) == 1, "wrong number of outputs"
  137. self.shape = opr.outputs[0].shape
  138. self.dtype = opr.outputs[0].dtype
  139. self.name = opr.outputs[0].name
  140. self.device = opr.outputs[0].comp_node
  141. self._opr = opr
  142. return self
  143. def compile(self, graph):
  144. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  145. self._opr = outputs.owner
  146. if len(self.outputs) == 0:
  147. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  148. self.outputs[0].var = outputs
  149. assert self.outputs[0].owner is self
  150. class ImmutableTensor(OpNode):
  151. type = "ImmutableTensor"
  152. def __init__(self, data=None, name=None, device=None, graph=None):
  153. super().__init__()
  154. self.name = name
  155. self.outputs = []
  156. self.graph = graph
  157. if data is not None:
  158. self.set_value(data, device)
  159. @property
  160. def device(self):
  161. return self._opr.outputs[0].comp_node if self._opr else None
  162. @device.setter
  163. def device(self, device):
  164. self.set_value(self.numpy(), device)
  165. @property
  166. def shape(self):
  167. return self.outputs[0].shape
  168. @property
  169. def dtype(self):
  170. return self._opr.outputs[0].dtype if self._opr else None
  171. def numpy(self):
  172. return self._opr.outputs[0].value if self._opr else None
  173. def set_value(self, data, device=None):
  174. assert self.graph is not None
  175. cn = device if device else self.device
  176. assert isinstance(data, (int, float, Sequence, np.ndarray))
  177. if not isinstance(data, np.ndarray):
  178. data = np.array(data)
  179. if data.dtype == np.float64:
  180. data = data.astype(np.float32)
  181. elif data.dtype == np.int64:
  182. data = data.astype(np.int32)
  183. varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
  184. if len(self.outputs) == 0:
  185. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  186. self.outputs[0].var = varnode
  187. self._opr = varnode.owner
  188. @classmethod
  189. def load(cls, opr):
  190. self = cls()
  191. self.outputs = []
  192. self._opr = opr
  193. self.name = opr.outputs[0].name
  194. self.graph = opr.graph
  195. return self
  196. def compile(self, graph):
  197. assert self.outputs[0].var is self._opr.outputs[0]
  198. assert self.outputs[0].owner is self
  199. if self.graph != graph:
  200. self.graph = graph
  201. self.set_value(self.numpy())
  202. if self.name is not None:
  203. self.outputs[0].var.name = self.name
  204. class ReadOnlyOpNode(OpNode):
  205. @classmethod
  206. def load(cls, opr):
  207. obj = super(ReadOnlyOpNode, cls).load(opr)
  208. obj.type = opr.type
  209. return obj
  210. def compile(self):
  211. assert self._opr is not None
  212. assert len(self.inputs) == len(self._opr.inputs)
  213. assert len(self.outputs) == len(self._opr.outputs)
  214. repl_dict = {}
  215. for ind, i in enumerate(self.inputs):
  216. if i.var != self._opr.inputs[ind]:
  217. repl_dict[self._opr.inputs[ind]] = i.var
  218. if bool(repl_dict):
  219. out_vars = replace_vars(self._opr.outputs, repl_dict)
  220. for ind, o in enumerate(self.outputs):
  221. o.var = out_vars[ind]
  222. class Elemwise(OpNode):
  223. type = "Elemwise"
  224. opdef = builtin.Elemwise
  225. class ElemwiseMultiType(OpNode):
  226. type = "ElemwiseMultiType"
  227. opdef = builtin.ElemwiseMultiType
  228. @classmethod
  229. def load(cls, opr):
  230. obj = super(ElemwiseMultiType, cls).load(opr)
  231. obj.params["dtype"] = opr.outputs[0].dtype
  232. return obj
  233. @register_flops(Elemwise, ElemwiseMultiType)
  234. def flops_elemwise(opnode: Elemwise, inputs, outputs):
  235. return np.prod(outputs[0].shape)
  236. class Reduce(OpNode):
  237. type = "Reduce"
  238. opdef = builtin.Reduce
  239. class TypeCvt(OpNode):
  240. type = "TypeCvt"
  241. opdef = builtin.TypeCvt
  242. @classmethod
  243. def load(cls, opr):
  244. obj = super(TypeCvt, cls).load(opr)
  245. t_dtype = opr.outputs[0].dtype
  246. obj.params["dtype"] = t_dtype
  247. return obj
  248. class MatrixInverse(OpNode):
  249. type = "MatrixInverse"
  250. opdef = builtin.MatrixInverse
  251. class MatrixMul(OpNode):
  252. type = "MatrixMul"
  253. opdef = builtin.MatrixMul
  254. @register_flops(MatrixMul)
  255. def flops_matmul(opnode: MatrixMul, inputs, outputs):
  256. assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
  257. mid_shape = inputs[0].shape[1]
  258. return np.prod(outputs[0].shape) * mid_shape
  259. class BatchedMatrixMul(OpNode):
  260. type = "BatchedMatmul"
  261. opdef = builtin.BatchedMatrixMul
  262. @register_flops(BatchedMatrixMul)
  263. def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
  264. assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
  265. mid_shape = inputs[0].shape[2]
  266. return np.prod(outputs[0].shape) * mid_shape
  267. class Dot(OpNode):
  268. type = "Dot"
  269. opdef = builtin.Dot
  270. class SVD(OpNode):
  271. type = "SVD"
  272. opdef = builtin.SVD
  273. class ConvolutionForward(OpNode):
  274. type = "Convolution"
  275. opdef = builtin.Convolution
  276. class ConvolutionBackwardData(OpNode):
  277. type = "ConvTranspose"
  278. opdef = builtin.ConvolutionBackwardData
  279. class DeformableConvForward(OpNode):
  280. type = "DeformableConv"
  281. opdef = builtin.DeformableConv
  282. class GroupLocalForward(OpNode):
  283. type = "GroupLocal"
  284. opdef = builtin.GroupLocal
  285. class PoolingForward(OpNode):
  286. type = "Pooling"
  287. opdef = builtin.Pooling
  288. class AdaptivePoolingForward(OpNode):
  289. type = "AdaptivePooling"
  290. opdef = builtin.AdaptivePooling
  291. class ROIPoolingForward(OpNode):
  292. type = "ROIPooling"
  293. opdef = builtin.ROIPooling
  294. class DeformablePSROIPoolingForward(OpNode):
  295. type = "DeformablePSROIPooling"
  296. opdef = builtin.DeformablePSROIPooling
  297. class ConvBiasForward(OpNode):
  298. type = "ConvBias"
  299. opdef = builtin.ConvBias
  300. @classmethod
  301. def load(cls, opr):
  302. obj = super(ConvBiasForward, cls).load(opr)
  303. obj.params["dtype"] = opr.outputs[0].dtype
  304. return obj
  305. @register_flops(
  306. ConvolutionForward, ConvBiasForward,
  307. )
  308. def flops_conv(opnode: ConvolutionForward, inputs, outputs):
  309. param_W_shape = inputs[1].shape
  310. kh = param_W_shape[-2]
  311. kw = param_W_shape[-1]
  312. if len(param_W_shape) == 5:
  313. num_input = param_W_shape[2]
  314. else:
  315. num_input = param_W_shape[1]
  316. NCHW = np.prod(outputs[0].shape)
  317. bias = 1 if isinstance(opnode, ConvBiasForward) else 0
  318. # N x Cout x H x W x (Cin x Kw x Kh)
  319. return NCHW * (num_input * kw * kh + bias)
  320. @register_receptive_field(ConvolutionForward, ConvBiasForward)
  321. def receptive_field(opnode: ConvolutionForward, inputs, outputs):
  322. pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
  323. param_W_shape = inputs[1].shape
  324. kh = param_W_shape[-2]
  325. kw = param_W_shape[-1]
  326. rf = (
  327. kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
  328. kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
  329. )
  330. stride = (
  331. opnode.params["stride_h"] * pre_stride[0],
  332. opnode.params["stride_w"] * pre_stride[1],
  333. )
  334. opnode._rf = rf
  335. opnode._stride = stride
  336. return rf, stride
  337. class BatchConvBiasForward(OpNode):
  338. type = "BatchConvBias"
  339. opdef = builtin.BatchConvBias
  340. @classmethod
  341. def load(cls, opr):
  342. obj = super(BatchConvBiasForward, cls).load(opr)
  343. obj.params["dtype"] = opr.outputs[0].dtype
  344. return obj
  345. class BatchNormForward(OpNode):
  346. type = "BatchNorm"
  347. opdef = builtin.BatchNorm
  348. output_idx = -1
  349. class ROIAlignForward(OpNode):
  350. type = "ROIAlign"
  351. opdef = builtin.ROIAlign
  352. class WarpPerspectiveForward(OpNode):
  353. type = "WarpPerspective"
  354. opdef = builtin.WarpPerspective
  355. class WarpAffineForward(OpNode):
  356. type = "WarpAffine"
  357. opdef = builtin.WarpAffine
  358. class RemapForward(OpNode):
  359. type = "Remap"
  360. opdef = builtin.Remap
  361. class ResizeForward(OpNode):
  362. type = "Resize"
  363. opdef = builtin.Resize
  364. class IndexingOneHot(OpNode):
  365. type = "IndexingOneHot"
  366. opdef = builtin.IndexingOneHot
  367. class IndexingSetOneHot(OpNode):
  368. type = "IndexingSetOneHot"
  369. opdef = builtin.IndexingSetOneHot
  370. class Copy(OpNode):
  371. type = "Copy"
  372. opdef = builtin.Copy
  373. @classmethod
  374. def load(cls, opr):
  375. obj = super(Copy, cls).load(opr)
  376. obj.params["comp_node"] = opr.outputs[0].comp_node
  377. return obj
  378. class ArgsortForward(OpNode):
  379. type = "Argsort"
  380. opdef = builtin.Argsort
  381. class Argmax(OpNode):
  382. type = "Argmax"
  383. opdef = builtin.Argmax
  384. class Argmin(OpNode):
  385. type = "Argmin"
  386. opdef = builtin.Argmin
  387. class CondTake(OpNode):
  388. type = "CondTake"
  389. opdef = builtin.CondTake
  390. class TopK(OpNode):
  391. type = "TopK"
  392. opdef = builtin.TopK
  393. class NvOf(OpNode):
  394. type = "NvOf"
  395. opdef = builtin.NvOf
  396. class RNGOpr(OpNode):
  397. @classmethod
  398. def load(cls, opr):
  399. obj = super(RNGOpr, cls).load(opr)
  400. if len(obj.params) == 3:
  401. obj.opdef = builtin.GaussianRNG
  402. obj.type = "GaussianRNG"
  403. else:
  404. obj.opdef = builtin.UniformRNG
  405. obj.type = "UniformRNG"
  406. return obj
  407. class Linspace(OpNode):
  408. type = "Linspace"
  409. opdef = builtin.Linspace
  410. @classmethod
  411. def load(cls, opr):
  412. obj = super(Linspace, cls).load(opr)
  413. obj.params["comp_node"] = opr.outputs[0].comp_node
  414. return obj
  415. class Eye(OpNode):
  416. type = "Eye"
  417. opdef = builtin.Eye
  418. @classmethod
  419. def load(cls, opr):
  420. obj = super(Eye, cls).load(opr)
  421. obj.params["dtype"] = opr.outputs[0].dtype
  422. obj.params["comp_node"] = opr.outputs[0].comp_node
  423. return obj
  424. class GetVarShape(OpNode):
  425. type = "GetVarShape"
  426. opdef = builtin.GetVarShape
  427. class Concat(OpNode):
  428. type = "Concat"
  429. opdef = builtin.Concat
  430. @classmethod
  431. def load(cls, opr):
  432. obj = super(Concat, cls).load(opr)
  433. obj.params["comp_node"] = Device("xpux").to_c()
  434. return obj
  435. class Broadcast(OpNode):
  436. type = "Broadcast"
  437. opdef = builtin.Broadcast
  438. class Identity(OpNode):
  439. type = "Identity"
  440. opdef = builtin.Identity
  441. class NMSKeep(OpNode):
  442. type = "NMSKeep"
  443. opdef = builtin.NMSKeep
  444. # class ParamPackSplit
  445. # class ParamPackConcat
  446. class Dimshuffle(OpNode):
  447. type = "Dimshuffle"
  448. opdef = builtin.Dimshuffle
  449. @classmethod
  450. def load(cls, opr):
  451. obj = super(Dimshuffle, cls).load(opr)
  452. del obj.params["ndim"]
  453. return obj
  454. class Reshape(OpNode):
  455. type = "Reshape"
  456. opdef = builtin.Reshape
  457. class AxisAddRemove(OpNode):
  458. type = "AxisAddRemove"
  459. @classmethod
  460. def load(cls, opr):
  461. obj = cls()
  462. obj.name = opr.name
  463. obj._opr = opr
  464. params = json.loads(opr.params)
  465. desc = params["desc"]
  466. method = None
  467. axis = []
  468. for i in desc:
  469. if method is None:
  470. method = i["method"]
  471. assert method == i["method"]
  472. axis.append(i["axisnum"])
  473. obj.params = {"axis": axis}
  474. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  475. return obj
  476. class IndexingBase(OpNode):
  477. @classmethod
  478. def load(cls, opr):
  479. obj = cls()
  480. obj.name = opr.name
  481. obj._opr = opr
  482. params = json.loads(opr.params)
  483. items = [
  484. [
  485. p["axis"],
  486. bool(p["begin"]),
  487. bool(p["end"]),
  488. bool(p["step"]),
  489. bool(p["idx"]),
  490. ]
  491. for p in params
  492. ]
  493. obj.params["items"] = items
  494. return obj
  495. class Subtensor(IndexingBase):
  496. type = "Subtensor"
  497. opdef = builtin.Subtensor
  498. class SetSubtensor(IndexingBase):
  499. type = "SetSubtensor"
  500. opdef = builtin.SetSubtensor
  501. class IncrSubtensor(IndexingBase):
  502. type = "IncrSubtensor"
  503. opdef = builtin.IncrSubtensor
  504. class IndexingMultiAxisVec(IndexingBase):
  505. type = "IndexingMultiAxisVec"
  506. opdef = builtin.IndexingMultiAxisVec
  507. class IndexingSetMultiAxisVec(IndexingBase):
  508. type = "IndexingSetMultiAxisVec"
  509. opdef = builtin.IndexingSetMultiAxisVec
  510. class IndexingIncrMultiAxisVec(IndexingBase):
  511. type = "IndexingIncrMultiAxisVec"
  512. opdef = builtin.IndexingIncrMultiAxisVec
  513. class MeshIndexing(IndexingBase):
  514. type = "MeshIndexing"
  515. opdef = builtin.MeshIndexing
  516. class SetMeshIndexing(IndexingBase):
  517. type = "SetMeshIndexing"
  518. opdef = builtin.SetMeshIndexing
  519. class IncrMeshIndexing(IndexingBase):
  520. type = "IncrMeshIndexing"
  521. opdef = builtin.IncrMeshIndexing
  522. class BatchedMeshIndexing(IndexingBase):
  523. type = "BatchedMeshIndexing"
  524. opdef = builtin.BatchedMeshIndexing
  525. class BatchedSetMeshIndexing(IndexingBase):
  526. type = "BatchedSetMeshIndexing"
  527. opdef = builtin.BatchedSetMeshIndexing
  528. class BatchedIncrMeshIndexing(IndexingBase):
  529. type = "BatchedIncrMeshIndexing"
  530. opdef = builtin.BatchedIncrMeshIndexing
  531. # class CollectiveComm
  532. # class RemoteSend
  533. # class RemoteRecv
  534. # class TQT
  535. # class FakeQuant
  536. # class InplaceAdd
  537. class AssertEqual(OpNode):
  538. type = "AssertEqual"
  539. opdef = builtin.AssertEqual
  540. class CvtColorForward(OpNode):
  541. type = "CvtColor"
  542. opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。