You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Callable
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._wrap import Device
  15. from ..core.ops import builtin
  16. from ..core.tensor.megbrain_graph import InputNode
  17. from ..tensor import Tensor
  18. from .comp_graph_tools import replace_vars
  19. from .module_stats import (
  20. preprocess_receptive_field,
  21. register_flops,
  22. register_receptive_field,
  23. )
  class NetworkNode:
      """Common base type shared by the variable and operator nodes in this file."""

      pass


  class VarNode(NetworkNode):
      """A named variable slot that may wrap an underlying graph variable.

      Attributes set by the constructor:
          var: the wrapped variable object, ``None`` until one is attached.
          owner: the operator node passed as ``owner_opr``.
          name: the name passed as ``name``.
          id: ``id(self)`` captured at construction time.
      """

      def __init__(self, owner_opr=None, name=None):
          self.var = None
          self.owner = owner_opr
          self.name = name
          self.id = id(self)

      @classmethod
      def load(cls, sym_var, owner_opr):
          """Create a node wrapping ``sym_var`` and owned by ``owner_opr``.

          The node's name is copied from ``sym_var.name``.
          """
          obj = cls()
          obj.var = sym_var  # mgb varnode
          obj.name = sym_var.name
          obj.owner = owner_opr
          return obj

      @property
      def shape(self):
          """Shape reported by the wrapped variable, or ``None`` if unavailable.

          ``None`` is returned when no variable is attached (or it is falsy)
          and when reading ``var.shape`` raises an ordinary exception.
          """
          if not self.var:
              return None
          try:
              return self.var.shape
          except Exception:
              # Best-effort lookup: an unreadable shape is reported as None.
              # Only ``Exception`` is caught so that KeyboardInterrupt and
              # SystemExit are no longer silently swallowed here.
              return None

      @property
      def dtype(self):
          """Dtype of the wrapped variable, or ``None`` when none is attached."""
          return self.var.dtype if self.var else None

      def set_owner_opr(self, owner_opr):
          """Record ``owner_opr`` as the operator owning this variable."""
          self.owner = owner_opr
  class OpNode(NetworkNode):
      """Base class for operator nodes.

      Class attributes:
          opdef: callable invoked as ``opdef(**params)`` by ``compile``;
              ``None`` here, set by subclasses.
          type: type tag of the operator; ``None`` here, set by subclasses.

      Instance attributes:
          inputs / outputs: lists of variable nodes attached to this operator.
          params: keyword arguments forwarded to ``opdef``.
          name: operator name (``None`` until assigned or loaded).
          id: ``id(self)`` captured at construction time.
      """

      opdef = None
      type = None

      def __init__(self):
          self.inputs = []
          self.outputs = []
          self.params = {}
          self._opr = None  # mgb opnode
          # Always define ``name`` so a directly constructed node exposes the
          # same attribute as one produced by ``load``.
          self.name = None
          self.id = id(self)

      @classmethod
      def load(cls, opr):
          """Create a node from ``opr``.

          ``opr.params`` is parsed as JSON into ``params``; the name and the
          operator object itself are stored on the new node.
          """
          obj = cls()
          obj.params = json.loads(opr.params)
          obj.name = opr.name
          obj._opr = opr
          return obj

      def compile(self, graph=None):
          """Invoke the op on the input variables and bind the results.

          ``graph`` is accepted for signature compatibility and is not used
          by this implementation. Asserts that the number of produced
          variables equals ``len(self.outputs)`` and that every output node is
          owned by this operator.
          """
          op = self.opdef(**self.params)
          args = [i.var for i in self.inputs]
          outputs = rt.invoke_op(op, args)
          assert len(outputs) == len(self.outputs)
          self._opr = outputs[0].owner
          for out_node, new_var in zip(self.outputs, outputs):
              out_node.var = new_var
              # Propagate the node's name onto the freshly created variable.
              out_node.var.name = out_node.name
              assert out_node.owner is self

      def add_inp_var(self, x):
          """Append ``x`` to the input list."""
          self.inputs.append(x)

      def add_out_var(self, x):
          """Append ``x`` to the output list."""
          self.outputs.append(x)
  def str_to_mge_class(classname):
      """Resolve ``classname`` to an attribute of this module.

      The name ``"RNGOpr<MegDNNOpr>"`` is looked up as ``"RNGOpr"``. When the
      module has no such attribute (or it is falsy), ``ReadOnlyOpNode`` is
      returned instead.
      """
      # TODO: use megbrain C++ RTTI to replace type string
      lookup_name = "RNGOpr" if classname == "RNGOpr<MegDNNOpr>" else classname
      found = getattr(sys.modules[__name__], lookup_name, None)
      if found:
          return found
      return ReadOnlyOpNode
  class Host2DeviceCopy(OpNode):
      """Operator node tagged "Host2DeviceCopy", compiled via ``rt.make_h2d``.

      Stores the ``shape``, ``dtype``, ``name`` and ``device`` that are handed
      to ``rt.make_h2d`` when the node is compiled.
      """

      type = "Host2DeviceCopy"

      def __init__(self, shape=None, dtype=None, name=None, device=None):
          super().__init__()
          self.shape = shape
          self.dtype = dtype
          self.name = name
          # A falsy ``device`` falls back to "xpux".
          self.device = Device(device if device else "xpux").to_c()
          self.outputs = []

      @classmethod
      def load(cls, opr):
          """Create a node from ``opr``, which must have exactly one output.

          Shape, dtype, name and device are read from that single output.
          """
          node = cls()
          node.outputs = []
          assert len(opr.outputs) == 1, "wrong number of outputs"
          node.shape = opr.outputs[0].shape
          node.dtype = opr.outputs[0].dtype
          node.name = opr.outputs[0].name
          node.device = opr.outputs[0].comp_node
          node._opr = opr
          return node

      def compile(self, graph):
          """Create the variable in ``graph`` and bind it to the first output.

          An output ``VarNode`` is created on demand when none exists yet.
          """
          made = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
          self._opr = made.owner
          if not self.outputs:
              self.outputs.append(VarNode(self, self.name))
          first_out = self.outputs[0]
          first_out.var = made
          assert first_out.owner is self
  class ImmutableTensor(OpNode):
      """Operator node tagged "ImmutableTensor", backed by ``rt.make_const``.

      The value lives in the underlying operator's first output; ``device``,
      ``dtype``, ``shape`` and ``numpy()`` all answer ``None`` while that
      information is not available yet.
      """

      type = "ImmutableTensor"

      def __init__(self, data=None, name=None, device=None, graph=None):
          super().__init__()
          self.name = name
          self.outputs = []
          self.graph = graph
          if data is not None:
              self.set_value(data, device)

      @property
      def device(self):
          """``comp_node`` of the operator's first output, or ``None``."""
          return self._opr.outputs[0].comp_node if self._opr else None

      @device.setter
      def device(self, device):
          # Rebuild the constant from its current value with the given device.
          self.set_value(self.numpy(), device)

      @property
      def shape(self):
          """Shape of the first output node, or ``None`` when there is none."""
          # ``load`` leaves ``outputs`` empty; answer ``None`` like the sibling
          # properties do instead of raising IndexError.
          return self.outputs[0].shape if self.outputs else None

      @property
      def dtype(self):
          """Dtype of the operator's first output, or ``None``."""
          return self._opr.outputs[0].dtype if self._opr else None

      def numpy(self):
          """Return ``value`` of the operator's first output, or ``None``."""
          return self._opr.outputs[0].value if self._opr else None

      def set_value(self, data, device=None):
          """Replace the constant with ``data``.

          Args:
              data: an ``int``, ``float`` or ``numpy.ndarray``. Scalars are
                  wrapped with ``np.array``; float64 data is cast to float32
                  and int64 data to int32 before the constant is created.
              device: target device; a falsy value reuses ``self.device``.

          Requires ``self.graph`` to be set (checked with ``assert``).
          """
          assert self.graph is not None
          cn = device if device else self.device
          assert isinstance(data, (int, float, np.ndarray))
          if isinstance(data, (int, float)):
              data = np.array(data)
          if data.dtype == np.float64:
              data = data.astype(np.float32)
          elif data.dtype == np.int64:
              data = data.astype(np.int32)
          varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
          if not self.outputs:
              self.outputs.append(VarNode(self, self.name))
          self.outputs[0].var = varnode
          self._opr = varnode.owner

      @classmethod
      def load(cls, opr):
          """Create a node from ``opr``; name and graph are taken from it."""
          self = cls()
          self.outputs = []
          self._opr = opr
          self.name = opr.outputs[0].name
          self.graph = opr.graph
          return self

      def compile(self, graph):
          """Bind the node to ``graph``.

          When ``graph`` differs from the stored graph, the constant is
          re-created in ``graph`` from its current value. A non-``None`` name
          is then written onto the output variable.
          """
          assert self.outputs[0].var is self._opr.outputs[0]
          assert self.outputs[0].owner is self
          if self.graph != graph:
              self.graph = graph
              self.set_value(self.numpy())
          if self.name is not None:
              self.outputs[0].var.name = self.name
  class ReadOnlyOpNode(OpNode):
      """Operator node that keeps the loaded operator and only rewires its inputs.

      ``str_to_mge_class`` returns this class for names it cannot resolve.
      """

      @classmethod
      def load(cls, opr):
          """Create a node from ``opr`` and copy ``opr.type`` onto it."""
          obj = super(ReadOnlyOpNode, cls).load(opr)
          obj.type = opr.type
          return obj

      def compile(self, graph=None):
          """Re-derive the outputs when any input variable was replaced.

          ``graph`` is accepted (and ignored) so that this override stays
          call-compatible with ``OpNode.compile(graph=None)``.

          Every input whose current variable differs from the loaded
          operator's input at the same position is collected into a
          replacement map. If the map is non-empty, ``replace_vars`` is applied
          to the loaded operator's outputs and the results are bound to this
          node's outputs.
          """
          assert self._opr is not None
          assert len(self.inputs) == len(self._opr.inputs)
          assert len(self.outputs) == len(self._opr.outputs)
          repl_dict = {}
          for ind, i in enumerate(self.inputs):
              if i.var != self._opr.inputs[ind]:
                  repl_dict[self._opr.inputs[ind]] = i.var
          if repl_dict:
              out_vars = replace_vars(self._opr.outputs, repl_dict)
              for ind, o in enumerate(self.outputs):
                  o.var = out_vars[ind]
  class Elemwise(OpNode):
      """Operator node tagged "Elemwise" whose op definition is ``builtin.Elemwise``."""

      opdef = builtin.Elemwise
      type = "Elemwise"
  class ElemwiseMultiType(OpNode):
      """Operator node tagged "ElemwiseMultiType" (``builtin.ElemwiseMultiType``)."""

      opdef = builtin.ElemwiseMultiType
      type = "ElemwiseMultiType"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then store the first output's dtype as ``params["dtype"]``."""
          node = super().load(opr)
          out_dtype = opr.outputs[0].dtype
          node.params["dtype"] = out_dtype
          return node
  @register_flops(Elemwise, ElemwiseMultiType)
  def flops_elemwise(opnode: Elemwise, inputs, outputs):
      """Return the product of the first output's shape dimensions."""
      out_shape = outputs[0].shape
      return np.prod(out_shape)
  class Reduce(OpNode):
      """Operator node tagged "Reduce" whose op definition is ``builtin.Reduce``."""

      opdef = builtin.Reduce
      type = "Reduce"
  class TypeCvt(OpNode):
      """Operator node tagged "TypeCvt" whose op definition is ``builtin.TypeCvt``."""

      opdef = builtin.TypeCvt
      type = "TypeCvt"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then store the first output's dtype as ``params["dtype"]``."""
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node
  class MatrixInverse(OpNode):
      """Operator node tagged "MatrixInverse" (``builtin.MatrixInverse``)."""

      opdef = builtin.MatrixInverse
      type = "MatrixInverse"
  class MatrixMul(OpNode):
      """Operator node tagged "MatrixMul" whose op definition is ``builtin.MatrixMul``."""

      opdef = builtin.MatrixMul
      type = "MatrixMul"
  @register_flops(MatrixMul)
  def flops_matmul(opnode: MatrixMul, inputs, outputs):
      """Return prod(output shape) times the second dimension of the first input.

      Both the first input and the first output must be 2-D (asserted).
      """
      lhs_shape = inputs[0].shape
      assert len(lhs_shape) == 2 and len(outputs[0].shape) == 2
      inner_dim = lhs_shape[1]
      return np.prod(outputs[0].shape) * inner_dim
  class BatchedMatrixMul(OpNode):
      """Operator node tagged "BatchedMatmul" (``builtin.BatchedMatrixMul``)."""

      opdef = builtin.BatchedMatrixMul
      type = "BatchedMatmul"
  @register_flops(BatchedMatrixMul)
  def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
      """Return prod(output shape) times the third dimension of the first input.

      Both the first input and the first output must be 3-D (asserted).
      """
      lhs_shape = inputs[0].shape
      assert len(lhs_shape) == 3 and len(outputs[0].shape) == 3
      inner_dim = lhs_shape[2]
      return np.prod(outputs[0].shape) * inner_dim
  class Dot(OpNode):
      """Operator node tagged "Dot" whose op definition is ``builtin.Dot``."""

      opdef = builtin.Dot
      type = "Dot"
  class SVD(OpNode):
      """Operator node tagged "SVD" whose op definition is ``builtin.SVD``."""

      opdef = builtin.SVD
      type = "SVD"
  class ConvolutionForward(OpNode):
      """Operator node tagged "Convolution" (``builtin.Convolution``)."""

      opdef = builtin.Convolution
      type = "Convolution"
  class ConvolutionBackwardData(OpNode):
      """Operator node tagged "ConvTranspose" (``builtin.ConvolutionBackwardData``)."""

      opdef = builtin.ConvolutionBackwardData
      type = "ConvTranspose"
  class DeformableConvForward(OpNode):
      """Operator node tagged "DeformableConv" (``builtin.DeformableConv``)."""

      opdef = builtin.DeformableConv
      type = "DeformableConv"
  class GroupLocalForward(OpNode):
      """Operator node tagged "GroupLocal" (``builtin.GroupLocal``)."""

      opdef = builtin.GroupLocal
      type = "GroupLocal"
  class PoolingForward(OpNode):
      """Operator node tagged "Pooling" whose op definition is ``builtin.Pooling``."""

      opdef = builtin.Pooling
      type = "Pooling"
  class AdaptivePoolingForward(OpNode):
      """Operator node tagged "AdaptivePooling" (``builtin.AdaptivePooling``)."""

      opdef = builtin.AdaptivePooling
      type = "AdaptivePooling"
  class ROIPoolingForward(OpNode):
      """Operator node tagged "ROIPooling" (``builtin.ROIPooling``)."""

      opdef = builtin.ROIPooling
      type = "ROIPooling"
  class DeformablePSROIPoolingForward(OpNode):
      """Operator node tagged "DeformablePSROIPooling" (``builtin.DeformablePSROIPooling``)."""

      opdef = builtin.DeformablePSROIPooling
      type = "DeformablePSROIPooling"
  class ConvBiasForward(OpNode):
      """Operator node tagged "ConvBias" whose op definition is ``builtin.ConvBias``."""

      opdef = builtin.ConvBias
      type = "ConvBias"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then store the first output's dtype as ``params["dtype"]``."""
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node
  @register_flops(
      ConvolutionForward, ConvBiasForward,
  )
  def flops_conv(opnode: ConvolutionForward, inputs, outputs):
      """Estimate FLOPs of a convolution from its weight and output shapes.

      The kernel height and width are the last two dimensions of the second
      input's shape. The input-channel count is dimension 2 of that shape when
      it is 5-D and dimension 1 otherwise. ``ConvBiasForward`` nodes add one
      extra operation per output element.
      """
      weight_shape = inputs[1].shape
      kh, kw = weight_shape[-2], weight_shape[-1]
      in_channels = weight_shape[2] if len(weight_shape) == 5 else weight_shape[1]
      out_elems = np.prod(outputs[0].shape)
      if isinstance(opnode, ConvBiasForward):
          bias = 1
      else:
          bias = 0
      # N x Cout x H x W x (Cin x Kw x Kh)
      return out_elems * (in_channels * kw * kh + bias)
  @register_receptive_field(ConvolutionForward, ConvBiasForward)
  def receptive_field(opnode: ConvolutionForward, inputs, outputs):
      """Compute the (height, width) receptive field and stride after this op.

      Starts from the values returned by ``preprocess_receptive_field`` and
      combines them with the kernel size (last two dimensions of the second
      input's shape) and the ``stride_h`` / ``stride_w`` params. The results
      are cached on ``opnode._rf`` and ``opnode._stride`` and returned as
      ``(rf, stride)``.
      """
      pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
      weight_shape = inputs[1].shape
      kh, kw = weight_shape[-2], weight_shape[-1]
      rf_h = kh * pre_stride[0] + pre_rf[0] - pre_stride[0]
      rf_w = kw * pre_stride[1] + pre_rf[1] - pre_stride[1]
      rf = (rf_h, rf_w)
      stride_h = opnode.params["stride_h"] * pre_stride[0]
      stride_w = opnode.params["stride_w"] * pre_stride[1]
      stride = (stride_h, stride_w)
      opnode._rf = rf
      opnode._stride = stride
      return rf, stride
  class BatchConvBiasForward(OpNode):
      """Operator node tagged "BatchConvBias" (``builtin.BatchConvBias``)."""

      opdef = builtin.BatchConvBias
      type = "BatchConvBias"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then store the first output's dtype as ``params["dtype"]``."""
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node
  class BatchNormForward(OpNode):
      """Operator node tagged "BatchNorm" whose op definition is ``builtin.BatchNorm``.

      Also carries the class attribute ``output_idx = -1``.
      """

      opdef = builtin.BatchNorm
      type = "BatchNorm"
      output_idx = -1
  class ROIAlignForward(OpNode):
      """Operator node tagged "ROIAlign" whose op definition is ``builtin.ROIAlign``."""

      opdef = builtin.ROIAlign
      type = "ROIAlign"
  class WarpPerspectiveForward(OpNode):
      """Operator node tagged "WarpPerspective" (``builtin.WarpPerspective``)."""

      opdef = builtin.WarpPerspective
      type = "WarpPerspective"
  class WarpAffineForward(OpNode):
      """Operator node tagged "WarpAffine" (``builtin.WarpAffine``)."""

      opdef = builtin.WarpAffine
      type = "WarpAffine"
  class RemapForward(OpNode):
      """Operator node tagged "Remap" whose op definition is ``builtin.Remap``."""

      opdef = builtin.Remap
      type = "Remap"
  class ResizeForward(OpNode):
      """Operator node tagged "Resize" whose op definition is ``builtin.Resize``."""

      opdef = builtin.Resize
      type = "Resize"
  class IndexingOneHot(OpNode):
      """Operator node tagged "IndexingOneHot" (``builtin.IndexingOneHot``)."""

      opdef = builtin.IndexingOneHot
      type = "IndexingOneHot"
  class IndexingSetOneHot(OpNode):
      """Operator node tagged "IndexingSetOneHot" (``builtin.IndexingSetOneHot``)."""

      opdef = builtin.IndexingSetOneHot
      type = "IndexingSetOneHot"
  class Copy(OpNode):
      """Operator node tagged "Copy" whose op definition is ``builtin.Copy``."""

      opdef = builtin.Copy
      type = "Copy"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then store the first output's ``comp_node`` in ``params``."""
          node = super().load(opr)
          node.params["comp_node"] = opr.outputs[0].comp_node
          return node
  class ArgsortForward(OpNode):
      """Operator node tagged "Argsort" whose op definition is ``builtin.Argsort``."""

      opdef = builtin.Argsort
      type = "Argsort"
  class Argmax(OpNode):
      """Operator node tagged "Argmax" whose op definition is ``builtin.Argmax``."""

      opdef = builtin.Argmax
      type = "Argmax"
  class Argmin(OpNode):
      """Operator node tagged "Argmin" whose op definition is ``builtin.Argmin``."""

      opdef = builtin.Argmin
      type = "Argmin"
  class CondTake(OpNode):
      """Operator node tagged "CondTake" whose op definition is ``builtin.CondTake``."""

      opdef = builtin.CondTake
      type = "CondTake"
  class TopK(OpNode):
      """Operator node tagged "TopK" whose op definition is ``builtin.TopK``."""

      opdef = builtin.TopK
      type = "TopK"
  class NvOf(OpNode):
      """Operator node tagged "NvOf" whose op definition is ``builtin.NvOf``."""

      opdef = builtin.NvOf
      type = "NvOf"
  class RNGOpr(OpNode):
      """Random-number operator node whose concrete kind is chosen at load time."""

      @classmethod
      def load(cls, opr):
          """Load via the base class and pick the op by parameter count.

          Exactly three parsed params select ``builtin.GaussianRNG`` with type
          "GaussianRNG"; any other count selects ``builtin.UniformRNG`` with
          type "UniformRNG". Both are set on the instance, not the class.
          """
          node = super().load(opr)
          is_gaussian = len(node.params) == 3
          node.opdef = builtin.GaussianRNG if is_gaussian else builtin.UniformRNG
          node.type = "GaussianRNG" if is_gaussian else "UniformRNG"
          return node
  class Linspace(OpNode):
      """Operator node tagged "Linspace" whose op definition is ``builtin.Linspace``."""

      opdef = builtin.Linspace
      type = "Linspace"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then store the first output's ``comp_node`` in ``params``."""
          node = super().load(opr)
          node.params["comp_node"] = opr.outputs[0].comp_node
          return node
  class Eye(OpNode):
      """Operator node tagged "Eye" whose op definition is ``builtin.Eye``."""

      opdef = builtin.Eye
      type = "Eye"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then copy dtype and ``comp_node`` of the first output into ``params``."""
          node = super().load(opr)
          for key in ("dtype", "comp_node"):
              node.params[key] = getattr(opr.outputs[0], key)
          return node
  class GetVarShape(OpNode):
      """Operator node tagged "GetVarShape" (``builtin.GetVarShape``)."""

      opdef = builtin.GetVarShape
      type = "GetVarShape"
  class Concat(OpNode):
      """Operator node tagged "Concat" whose op definition is ``builtin.Concat``."""

      opdef = builtin.Concat
      type = "Concat"

      @classmethod
      def load(cls, opr):
          """Load via the base class, then set ``params["comp_node"]`` to the "xpux" device."""
          node = super().load(opr)
          default_device = Device("xpux")
          node.params["comp_node"] = default_device.to_c()
          return node
  class Broadcast(OpNode):
      """Operator node tagged "Broadcast" whose op definition is ``builtin.Broadcast``."""

      opdef = builtin.Broadcast
      type = "Broadcast"
  class Identity(OpNode):
      """Operator node tagged "Identity" whose op definition is ``builtin.Identity``."""

      opdef = builtin.Identity
      type = "Identity"
  class NMSKeep(OpNode):
      """Operator node tagged "NMSKeep" whose op definition is ``builtin.NMSKeep``."""

      opdef = builtin.NMSKeep
      type = "NMSKeep"
  410. # class ParamPackSplit
  411. # class ParamPackConcat
  class Dimshuffle(OpNode):
      """Operator node tagged "Dimshuffle" whose op definition is ``builtin.Dimshuffle``."""

      type = "Dimshuffle"
      opdef = builtin.Dimshuffle

      @classmethod
      def load(cls, opr):
          """Load via the base class and drop the ``"ndim"`` entry from ``params``.

          The entry is removed with ``pop`` so that parameters which do not
          contain ``"ndim"`` load cleanly instead of raising ``KeyError``.
          """
          obj = super(Dimshuffle, cls).load(opr)
          obj.params.pop("ndim", None)
          return obj
  class Reshape(OpNode):
      """Operator node tagged "Reshape" whose op definition is ``builtin.Reshape``."""

      opdef = builtin.Reshape
      type = "Reshape"
  class AxisAddRemove(OpNode):
      """Operator node tagged "AxisAddRemove"; the op definition is chosen at load time."""

      type = "AxisAddRemove"

      @classmethod
      def load(cls, opr):
          """Create a node from ``opr`` using the ``"desc"`` list in its JSON params.

          Every entry of ``desc`` must carry the same ``"method"`` (asserted);
          the ``"axisnum"`` values are collected into ``params["axis"]``. A
          first-entry method of ``0`` selects ``builtin.AddAxis``, anything
          else ``builtin.RemoveAxis``.
          """
          node = cls()
          node.name = opr.name
          node._opr = opr
          desc = json.loads(opr.params)["desc"]
          method = None
          axes = []
          for entry in desc:
              if method is None:
                  method = entry["method"]
              assert method == entry["method"]
              axes.append(entry["axisnum"])
          node.params = {"axis": axes}
          if desc[0]["method"] == 0:
              node.opdef = builtin.AddAxis
          else:
              node.opdef = builtin.RemoveAxis
          return node
  class IndexingBase(OpNode):
      """Shared loader for the indexing operator nodes that subclass it."""

      @classmethod
      def load(cls, opr):
          """Create a node from ``opr`` and build ``params["items"]``.

          The JSON params are iterated entry by entry; each entry becomes
          ``[axis, bool(begin), bool(end), bool(step), bool(idx)]``.
          """
          node = cls()
          node.name = opr.name
          node._opr = opr
          flag_keys = ("begin", "end", "step", "idx")
          entries = json.loads(opr.params)
          node.params["items"] = [
              [entry["axis"]] + [bool(entry[key]) for key in flag_keys]
              for entry in entries
          ]
          return node
  class Subtensor(IndexingBase):
      """Indexing node tagged "Subtensor" whose op definition is ``builtin.Subtensor``."""

      opdef = builtin.Subtensor
      type = "Subtensor"
  class SetSubtensor(IndexingBase):
      """Indexing node tagged "SetSubtensor" (``builtin.SetSubtensor``)."""

      opdef = builtin.SetSubtensor
      type = "SetSubtensor"
  class IncrSubtensor(IndexingBase):
      """Indexing node tagged "IncrSubtensor" (``builtin.IncrSubtensor``)."""

      opdef = builtin.IncrSubtensor
      type = "IncrSubtensor"
  class IndexingMultiAxisVec(IndexingBase):
      """Indexing node tagged "IndexingMultiAxisVec" (``builtin.IndexingMultiAxisVec``)."""

      opdef = builtin.IndexingMultiAxisVec
      type = "IndexingMultiAxisVec"
  class IndexingSetMultiAxisVec(IndexingBase):
      """Indexing node tagged "IndexingSetMultiAxisVec" (``builtin.IndexingSetMultiAxisVec``)."""

      opdef = builtin.IndexingSetMultiAxisVec
      type = "IndexingSetMultiAxisVec"
  class IndexingIncrMultiAxisVec(IndexingBase):
      """Indexing node tagged "IndexingIncrMultiAxisVec" (``builtin.IndexingIncrMultiAxisVec``)."""

      opdef = builtin.IndexingIncrMultiAxisVec
      type = "IndexingIncrMultiAxisVec"
  class MeshIndexing(IndexingBase):
      """Indexing node tagged "MeshIndexing" (``builtin.MeshIndexing``)."""

      opdef = builtin.MeshIndexing
      type = "MeshIndexing"
  class SetMeshIndexing(IndexingBase):
      """Indexing node tagged "SetMeshIndexing" (``builtin.SetMeshIndexing``)."""

      opdef = builtin.SetMeshIndexing
      type = "SetMeshIndexing"
  class IncrMeshIndexing(IndexingBase):
      """Indexing node tagged "IncrMeshIndexing" (``builtin.IncrMeshIndexing``)."""

      opdef = builtin.IncrMeshIndexing
      type = "IncrMeshIndexing"
  class BatchedMeshIndexing(IndexingBase):
      """Indexing node tagged "BatchedMeshIndexing" (``builtin.BatchedMeshIndexing``)."""

      opdef = builtin.BatchedMeshIndexing
      type = "BatchedMeshIndexing"
  class BatchedSetMeshIndexing(IndexingBase):
      """Indexing node tagged "BatchedSetMeshIndexing" (``builtin.BatchedSetMeshIndexing``)."""

      opdef = builtin.BatchedSetMeshIndexing
      type = "BatchedSetMeshIndexing"
  class BatchedIncrMeshIndexing(IndexingBase):
      """Indexing node tagged "BatchedIncrMeshIndexing" (``builtin.BatchedIncrMeshIndexing``)."""

      opdef = builtin.BatchedIncrMeshIndexing
      type = "BatchedIncrMeshIndexing"
  497. # class CollectiveComm
  498. # class RemoteSend
  499. # class RemoteRecv
  500. # class TQT
  501. # class FakeQuant
  502. # class InplaceAdd
  class AssertEqual(OpNode):
      """Operator node tagged "AssertEqual" (``builtin.AssertEqual``)."""

      opdef = builtin.AssertEqual
      type = "AssertEqual"
  class CvtColorForward(OpNode):
      """Operator node tagged "CvtColor" whose op definition is ``builtin.CvtColor``."""

      opdef = builtin.CvtColor
      type = "CvtColor"

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台