
network_node.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json
import sys
from typing import Callable, Sequence

import numpy as np

from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import SymbolVar
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
from .module_stats import (
    preprocess_receptive_field,
    register_flops,
    register_receptive_field,
)

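# The classes in this module are lightweight Python wrappers used for graph
# surgery on a loaded MegEngine graph: VarNode wraps a megbrain SymbolVar and
# gives it tensor-like (array/indexing) behaviour, while OpNode and its
# subclasses wrap individual megbrain operators and know how to rebuild
# ("compile") themselves into a graph.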
class NetworkNode:
    pass


class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
    pass


class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
    def __init__(self, var=None, *, owner_opr=None, name=None):
        SymbolVar.__init__(self, var)
        self.owner = owner_opr
        self.name = name
        self.id = id(self)

    @classmethod
    def load(cls, sym_var, owner_opr):
        obj = cls()
        obj.var = sym_var  # mgb varnode
        obj.name = sym_var.name
        obj.owner = owner_opr
        return obj

    @property
    def shape(self):
        rst = None
        if self.var:
            try:
                rst = self.var.shape
            except:
                rst = None
        return rst

    @property
    def dtype(self):
        return self.var.dtype if self.var else None

    def __bool__(self):
        return False

    __index__ = None
    __int__ = None
    __float__ = None
    __complex__ = None

    def __hash__(self):
        return id(self)

    @property
    def _tuple_shape(self):
        return self.var.shape

    def numpy(self):
        o = OutputNode(self.var)
        self.graph.compile(o.outputs).execute()
        return o.get_value().numpy()

    def __getitem__(self, index):
        return _getitem(self, index)

    def __setitem__(self, index, value):
        if index is not Ellipsis:
            value = _setitem(self, index, value)
        if self.owner is not None:
            idx = self.owner.outputs.index(self)
            self.owner.outputs[idx] = VarNode(
                self.var, owner_opr=self.owner, name=self.var.name
            )
        self.var = value.var
        self.owner = None

    def set_owner_opr(self, owner_opr):
        self.owner = owner_opr

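# OpNode is the generic operator wrapper: load() pulls the operator parameters
# out of the serialized JSON, and compile() re-creates the operator from its
# opdef, re-invokes it on the (possibly replaced) input vars, and rebinds the
# output VarNodes.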
class OpNode(NetworkNode):
    opdef = None
    type = None

    def __init__(self):
        self.inputs = []
        self.outputs = []
        self.params = {}
        self._opr = None  # mgb opnode
        self.id = id(self)

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.params = json.loads(opr.params)
        obj.name = opr.name
        obj._opr = opr
        return obj

    def compile(self, graph=None):
        op = self.opdef(**self.params)
        args = [i.var for i in self.inputs]
        outputs = rt.invoke_op(op, args)
        assert len(outputs) == len(self.outputs)
        self._opr = outputs[0].owner
        for i in range(len(self.outputs)):
            self.outputs[i].var = outputs[i]
            self.outputs[i].var.name = self.outputs[i].name
            assert self.outputs[i].owner is self

    def add_inp_var(self, x):
        self.inputs.append(x)

    def add_out_var(self, x):
        self.outputs.append(x)

    def __repr__(self):
        return "%s{%s}" % (self.name, self.type)


def str_to_mge_class(classname):
    # TODO: use megbrain C++ RTTI to replace type string
    if classname == "RNGOpr<MegDNNOpr>":
        classname = "RNGOpr"
    oprcls = getattr(sys.modules[__name__], classname, None)
    return oprcls if oprcls else ReadOnlyOpNode

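# Concrete operator wrappers. str_to_mge_class() above maps a megbrain type
# string to one of the classes below; anything without a dedicated wrapper
# falls back to ReadOnlyOpNode, which can only substitute input vars rather
# than change parameters. Host2DeviceCopy (graph input) and ImmutableTensor
# (constant) build their output vars directly instead of going through
# OpNode.compile.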
class Host2DeviceCopy(OpNode):
    type = "Host2DeviceCopy"

    def __init__(self, shape=None, dtype=None, name=None, device=None):
        super().__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self.device = Device(device).to_c() if device else Device("xpux").to_c()
        self.outputs = []

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        assert len(opr.outputs) == 1, "wrong number of outputs"
        self.shape = opr.outputs[0].shape
        self.dtype = opr.outputs[0].dtype
        self.name = opr.outputs[0].name
        self.device = opr.outputs[0].comp_node
        self._opr = opr
        return self

    def compile(self, graph):
        outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
        self._opr = outputs.owner
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(owner_opr=self, name=self.name))
        self.outputs[0].var = outputs
        assert self.outputs[0].owner is self


class ImmutableTensor(OpNode):
    type = "ImmutableTensor"

    def __init__(self, data=None, name=None, device=None, graph=None):
        super().__init__()
        self.name = name
        self.outputs = []
        self.graph = graph
        if data is not None:
            self.set_value(data, device)

    @property
    def device(self):
        return self._opr.outputs[0].comp_node if self._opr else None

    @device.setter
    def device(self, device):
        self.set_value(self.numpy(), device)

    @property
    def shape(self):
        return self.outputs[0].shape

    @property
    def dtype(self):
        return self._opr.outputs[0].dtype if self._opr else None

    def numpy(self):
        return self._opr.outputs[0].value if self._opr else None

    def set_value(self, data, device=None):
        assert self.graph is not None
        cn = device if device else self.device
        assert isinstance(data, (int, float, Sequence, np.ndarray))
        if not isinstance(data, np.ndarray):
            data = np.array(data)
            if data.dtype == np.float64:
                data = data.astype(np.float32)
            elif data.dtype == np.int64:
                data = data.astype(np.int32)
        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(owner_opr=self, name=self.name))
        self.outputs[0].var = varnode
        self._opr = varnode.owner

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        self._opr = opr
        self.name = opr.outputs[0].name
        self.graph = opr.graph
        return self

    def compile(self, graph):
        assert self.outputs[0].var is self._opr.outputs[0]
        assert self.outputs[0].owner is self
        if self.graph != graph:
            self.graph = graph
            self.set_value(self.numpy())
        if self.name is not None:
            self.outputs[0].var.name = self.name


class ReadOnlyOpNode(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(ReadOnlyOpNode, cls).load(opr)
        obj.type = opr.type
        return obj

    def compile(self):
        assert self._opr is not None
        assert len(self.inputs) == len(self._opr.inputs)
        assert len(self.outputs) == len(self._opr.outputs)
        repl_dict = {}
        for ind, i in enumerate(self.inputs):
            if i.var != self._opr.inputs[ind]:
                repl_dict[self._opr.inputs[ind]] = i.var
        if bool(repl_dict):
            out_vars = replace_vars(self._opr.outputs, repl_dict)
            for ind, o in enumerate(self.outputs):
                o.var = out_vars[ind]


class Elemwise(OpNode):
    type = "Elemwise"
    opdef = builtin.Elemwise

    def __repr__(self):
        return "%s{Elemwise:%s}" % (self.name, self.params["mode"])


class ElemwiseMultiType(OpNode):
    type = "ElemwiseMultiType"
    opdef = builtin.ElemwiseMultiType

    def __repr__(self):
        return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])

    @classmethod
    def load(cls, opr):
        obj = super(ElemwiseMultiType, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj

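# FLOPs hooks registered via register_flops. The convention used below counts
# one operation per output element, times the reduction length for matmul-like
# ops, i.e. multiply-accumulates rather than separate multiplies and adds.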
@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
    return np.prod(outputs[0].shape)


class Reduce(OpNode):
    type = "Reduce"
    opdef = builtin.Reduce


class TypeCvt(OpNode):
    type = "TypeCvt"
    opdef = builtin.TypeCvt

    @classmethod
    def load(cls, opr):
        obj = super(TypeCvt, cls).load(opr)
        t_dtype = opr.outputs[0].dtype
        obj.params["dtype"] = t_dtype
        return obj


class MatrixInverse(OpNode):
    type = "MatrixInverse"
    opdef = builtin.MatrixInverse


class MatrixMul(OpNode):
    type = "MatrixMul"
    opdef = builtin.MatrixMul


@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
    assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
    mid_shape = inputs[0].shape[1]
    return np.prod(outputs[0].shape) * mid_shape


class BatchedMatrixMul(OpNode):
    type = "BatchedMatmul"
    opdef = builtin.BatchedMatrixMul


@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
    assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
    mid_shape = inputs[0].shape[2]
    return np.prod(outputs[0].shape) * mid_shape


class Dot(OpNode):
    type = "Dot"
    opdef = builtin.Dot


class SVD(OpNode):
    type = "SVD"
    opdef = builtin.SVD


class ConvolutionForward(OpNode):
    type = "Convolution"
    opdef = builtin.Convolution


class ConvolutionBackwardData(OpNode):
    type = "ConvTranspose"
    opdef = builtin.ConvolutionBackwardData


class DeformableConvForward(OpNode):
    type = "DeformableConv"
    opdef = builtin.DeformableConv


class GroupLocalForward(OpNode):
    type = "GroupLocal"
    opdef = builtin.GroupLocal


class PoolingForward(OpNode):
    type = "Pooling"
    opdef = builtin.Pooling


class AdaptivePoolingForward(OpNode):
    type = "AdaptivePooling"
    opdef = builtin.AdaptivePooling


class ROIPoolingForward(OpNode):
    type = "ROIPooling"
    opdef = builtin.ROIPooling


class DeformablePSROIPoolingForward(OpNode):
    type = "DeformablePSROIPooling"
    opdef = builtin.DeformablePSROIPooling


class ConvBiasForward(OpNode):
    type = "ConvBias"
    opdef = builtin.ConvBias

    @classmethod
    def load(cls, opr):
        obj = super(ConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj

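# Convolution FLOPs: every output element (N x Cout x H x W) costs
# Cin x Kh x Kw multiply-accumulates, plus one add when a bias is fused in
# (ConvBias). A 5-D weight shape corresponds to a grouped convolution, where
# the per-group input-channel count sits at index 2 instead of index 1.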
@register_flops(
    ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
    param_W_shape = inputs[1].shape
    kh = param_W_shape[-2]
    kw = param_W_shape[-1]
    if len(param_W_shape) == 5:
        num_input = param_W_shape[2]
    else:
        num_input = param_W_shape[1]
    NCHW = np.prod(outputs[0].shape)
    bias = 1 if isinstance(opnode, ConvBiasForward) else 0
    # N x Cout x H x W x (Cin x Kw x Kh)
    return NCHW * (num_input * kw * kh + bias)

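# Receptive-field hook: given the receptive field and cumulative stride of the
# input feature map (pre_rf, pre_stride), a kh x kw kernel with stride
# (stride_h, stride_w) updates them as
#     rf = k * pre_stride + pre_rf - pre_stride
#     stride = s * pre_stride
# which is the usual recurrence for stacked spatial operators.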
@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
    pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
    param_W_shape = inputs[1].shape
    kh = param_W_shape[-2]
    kw = param_W_shape[-1]
    rf = (
        kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
        kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
    )
    stride = (
        opnode.params["stride_h"] * pre_stride[0],
        opnode.params["stride_w"] * pre_stride[1],
    )
    opnode._rf = rf
    opnode._stride = stride
    return rf, stride

class BatchConvBiasForward(OpNode):
    type = "BatchConvBias"
    opdef = builtin.BatchConvBias

    @classmethod
    def load(cls, opr):
        obj = super(BatchConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


class BatchNormForward(OpNode):
    type = "BatchNorm"
    opdef = builtin.BatchNorm
    output_idx = -1


class ROIAlignForward(OpNode):
    type = "ROIAlign"
    opdef = builtin.ROIAlign


class WarpPerspectiveForward(OpNode):
    type = "WarpPerspective"
    opdef = builtin.WarpPerspective


class WarpAffineForward(OpNode):
    type = "WarpAffine"
    opdef = builtin.WarpAffine


class RemapForward(OpNode):
    type = "Remap"
    opdef = builtin.Remap


class ResizeForward(OpNode):
    type = "Resize"
    opdef = builtin.Resize


class IndexingOneHot(OpNode):
    type = "IndexingOneHot"
    opdef = builtin.IndexingOneHot


class IndexingSetOneHot(OpNode):
    type = "IndexingSetOneHot"
    opdef = builtin.IndexingSetOneHot


class Copy(OpNode):
    type = "Copy"
    opdef = builtin.Copy

    @classmethod
    def load(cls, opr):
        obj = super(Copy, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class ArgsortForward(OpNode):
    type = "Argsort"
    opdef = builtin.Argsort


class Argmax(OpNode):
    type = "Argmax"
    opdef = builtin.Argmax


class Argmin(OpNode):
    type = "Argmin"
    opdef = builtin.Argmin


class CondTake(OpNode):
    type = "CondTake"
    opdef = builtin.CondTake


class TopK(OpNode):
    type = "TopK"
    opdef = builtin.TopK


class NvOf(OpNode):
    type = "NvOf"
    opdef = builtin.NvOf


class RNGOpr(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(RNGOpr, cls).load(opr)
        if len(obj.params) == 3:
            obj.opdef = builtin.GaussianRNG
            obj.type = "GaussianRNG"
        else:
            obj.opdef = builtin.UniformRNG
            obj.type = "UniformRNG"
        return obj


class Linspace(OpNode):
    type = "Linspace"
    opdef = builtin.Linspace

    @classmethod
    def load(cls, opr):
        obj = super(Linspace, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class Eye(OpNode):
    type = "Eye"
    opdef = builtin.Eye

    @classmethod
    def load(cls, opr):
        obj = super(Eye, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class GetVarShape(OpNode):
    type = "GetVarShape"
    opdef = builtin.GetVarShape


class Concat(OpNode):
    type = "Concat"
    opdef = builtin.Concat

    @classmethod
    def load(cls, opr):
        obj = super(Concat, cls).load(opr)
        obj.params["comp_node"] = Device("xpux").to_c()
        return obj


class Broadcast(OpNode):
    type = "Broadcast"
    opdef = builtin.Broadcast


class Identity(OpNode):
    type = "Identity"
    opdef = builtin.Identity


class NMSKeep(OpNode):
    type = "NMSKeep"
    opdef = builtin.NMSKeep


# class ParamPackSplit
# class ParamPackConcat


class Dimshuffle(OpNode):
    type = "Dimshuffle"
    opdef = builtin.Dimshuffle

    @classmethod
    def load(cls, opr):
        obj = super(Dimshuffle, cls).load(opr)
        del obj.params["ndim"]
        return obj


class Reshape(OpNode):
    type = "Reshape"
    opdef = builtin.Reshape


class AxisAddRemove(OpNode):
    type = "AxisAddRemove"

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        desc = params["desc"]
        method = None
        axis = []
        for i in desc:
            if method is None:
                method = i["method"]
            assert method == i["method"]
            axis.append(i["axisnum"])
        obj.params = {"axis": axis}
        obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
        return obj

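# The Subtensor-style indexing operators share one serialized format: each
# entry in the params list names an axis plus flags for whether begin/end/step
# and an index component are present. IndexingBase.load repacks these entries
# into the "items" list expected by the corresponding builtin opdefs.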
class IndexingBase(OpNode):
    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        items = [
            [
                p["axis"],
                bool(p["begin"]),
                bool(p["end"]),
                bool(p["step"]),
                bool(p["idx"]),
            ]
            for p in params
        ]
        obj.params["items"] = items
        return obj


class Subtensor(IndexingBase):
    type = "Subtensor"
    opdef = builtin.Subtensor


class SetSubtensor(IndexingBase):
    type = "SetSubtensor"
    opdef = builtin.SetSubtensor


class IncrSubtensor(IndexingBase):
    type = "IncrSubtensor"
    opdef = builtin.IncrSubtensor


class IndexingMultiAxisVec(IndexingBase):
    type = "IndexingMultiAxisVec"
    opdef = builtin.IndexingMultiAxisVec


class IndexingSetMultiAxisVec(IndexingBase):
    type = "IndexingSetMultiAxisVec"
    opdef = builtin.IndexingSetMultiAxisVec


class IndexingIncrMultiAxisVec(IndexingBase):
    type = "IndexingIncrMultiAxisVec"
    opdef = builtin.IndexingIncrMultiAxisVec


class MeshIndexing(IndexingBase):
    type = "MeshIndexing"
    opdef = builtin.MeshIndexing


class SetMeshIndexing(IndexingBase):
    type = "SetMeshIndexing"
    opdef = builtin.SetMeshIndexing


class IncrMeshIndexing(IndexingBase):
    type = "IncrMeshIndexing"
    opdef = builtin.IncrMeshIndexing


class BatchedMeshIndexing(IndexingBase):
    type = "BatchedMeshIndexing"
    opdef = builtin.BatchedMeshIndexing


class BatchedSetMeshIndexing(IndexingBase):
    type = "BatchedSetMeshIndexing"
    opdef = builtin.BatchedSetMeshIndexing


class BatchedIncrMeshIndexing(IndexingBase):
    type = "BatchedIncrMeshIndexing"
    opdef = builtin.BatchedIncrMeshIndexing


# class CollectiveComm
# class RemoteSend
# class RemoteRecv
# class TQT
# class FakeQuant
# class InplaceAdd


class AssertEqual(OpNode):
    type = "AssertEqual"
    opdef = builtin.AssertEqual


class CvtColorForward(OpNode):
    type = "CvtColor"
    opdef = builtin.CvtColor
