You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

network_node.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Callable
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._wrap import Device
  15. from ..core.ops import builtin
  16. from ..core.tensor.megbrain_graph import InputNode
  17. from ..tensor import Tensor
  18. from .comp_graph_tools import replace_vars
  19. class NetworkNode:
  20. pass
  21. class VarNode(NetworkNode):
  22. def __init__(self, owner_opr=None, name=None):
  23. self.var = None
  24. self.owner = owner_opr
  25. self.name = name
  26. self.id = id(self)
  27. @classmethod
  28. def load(cls, sym_var, owner_opr):
  29. obj = cls()
  30. obj.var = sym_var # mgb varnode
  31. obj.name = sym_var.name
  32. obj.owner = owner_opr
  33. return obj
  34. @property
  35. def shape(self):
  36. rst = None
  37. if self.var:
  38. try:
  39. rst = self.var.shape
  40. except:
  41. rst = None
  42. return rst
  43. @property
  44. def dtype(self):
  45. return self.var.dtype if self.var else None
  46. def set_owner_opr(self, owner_opr):
  47. self.owner = owner_opr
  48. class OpNode(NetworkNode):
  49. opdef = None
  50. type = None
  51. def __init__(self):
  52. self.inputs = []
  53. self.outputs = []
  54. self.params = {}
  55. self._opr = None # mgb opnode
  56. self.id = id(self)
  57. @classmethod
  58. def load(cls, opr):
  59. obj = cls()
  60. obj.params = json.loads(opr.params)
  61. obj.name = opr.name
  62. obj._opr = opr
  63. return obj
  64. def compile(self, graph=None):
  65. op = self.opdef(**self.params)
  66. args = [i.var for i in self.inputs]
  67. outputs = rt.invoke_op(op, args)
  68. assert len(outputs) == len(self.outputs)
  69. self._opr = outputs[0].owner
  70. for i in range(len(self.outputs)):
  71. self.outputs[i].var = outputs[i]
  72. self.outputs[i].var.name = self.outputs[i].name
  73. assert self.outputs[i].owner is self
  74. def add_inp_var(self, x):
  75. self.inputs.append(x)
  76. def add_out_var(self, x):
  77. self.outputs.append(x)
  78. def str_to_mge_class(classname):
  79. # TODO: use megbrain C++ RTTI to replace type string
  80. if classname == "RNGOpr<MegDNNOpr>":
  81. classname = "RNGOpr"
  82. oprcls = getattr(sys.modules[__name__], classname, None)
  83. return oprcls if oprcls else ReadOnlyOpNode
  84. class Host2DeviceCopy(OpNode):
  85. type = "Host2DeviceCopy"
  86. def __init__(self, shape=None, dtype=None, name=None, device=None):
  87. super().__init__()
  88. self.shape = shape
  89. self.dtype = dtype
  90. self.name = name
  91. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  92. self.outputs = []
  93. @classmethod
  94. def load(cls, opr):
  95. self = cls()
  96. self.outputs = []
  97. assert len(opr.outputs) == 1, "wrong number of outputs"
  98. self.shape = opr.outputs[0].shape
  99. self.dtype = opr.outputs[0].dtype
  100. self.name = opr.outputs[0].name
  101. self.device = opr.outputs[0].comp_node
  102. self._opr = opr
  103. return self
  104. def compile(self, graph):
  105. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  106. self._opr = outputs.owner
  107. if len(self.outputs) == 0:
  108. self.outputs.append(VarNode(self, self.name))
  109. self.outputs[0].var = outputs
  110. assert self.outputs[0].owner is self
  111. class ImmutableTensor(OpNode):
  112. type = "ImmutableTensor"
  113. def __init__(self, data=None, name=None, device=None, graph=None):
  114. super().__init__()
  115. self.name = name
  116. self.outputs = []
  117. self.graph = graph
  118. if data is not None:
  119. self.set_value(data, device)
  120. @property
  121. def device(self):
  122. return self._opr.outputs[0].comp_node if self._opr else None
  123. @device.setter
  124. def device(self, device):
  125. self.set_value(self.numpy(), device)
  126. @property
  127. def shape(self):
  128. return self.outputs[0].shape
  129. @property
  130. def dtype(self):
  131. return self._opr.outputs[0].dtype if self._opr else None
  132. def numpy(self):
  133. return self._opr.outputs[0].value if self._opr else None
  134. def set_value(self, data, device=None):
  135. assert self.graph is not None
  136. cn = device if device else self.device
  137. assert isinstance(data, (int, float, np.ndarray))
  138. if isinstance(data, (int, float)):
  139. data = np.array(data)
  140. if data.dtype == np.float64:
  141. data = data.astype(np.float32)
  142. elif data.dtype == np.int64:
  143. data = data.astype(np.int32)
  144. varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
  145. if len(self.outputs) == 0:
  146. self.outputs.append(VarNode(self, self.name))
  147. self.outputs[0].var = varnode
  148. self._opr = varnode.owner
  149. @classmethod
  150. def load(cls, opr):
  151. self = cls()
  152. self.outputs = []
  153. self._opr = opr
  154. self.name = opr.outputs[0].name
  155. self.graph = opr.graph
  156. return self
  157. def compile(self, graph):
  158. assert self.outputs[0].var is self._opr.outputs[0]
  159. assert self.outputs[0].owner is self
  160. if self.graph != graph:
  161. self.graph = graph
  162. self.set_value(self.numpy())
  163. if self.name is not None:
  164. self.outputs[0].var.name = self.name
  165. class ReadOnlyOpNode(OpNode):
  166. @classmethod
  167. def load(cls, opr):
  168. obj = super(ReadOnlyOpNode, cls).load(opr)
  169. obj.type = opr.type
  170. return obj
  171. def compile(self):
  172. assert self._opr is not None
  173. assert len(self.inputs) == len(self._opr.inputs)
  174. assert len(self.outputs) == len(self._opr.outputs)
  175. repl_dict = {}
  176. for ind, i in enumerate(self.inputs):
  177. if i.var != self._opr.inputs[ind]:
  178. repl_dict[self._opr.inputs[ind]] = i.var
  179. if bool(repl_dict):
  180. out_vars = replace_vars(self._opr.outputs, repl_dict)
  181. for ind, o in enumerate(self.outputs):
  182. o.var = out_vars[ind]
  183. class Elemwise(OpNode):
  184. type = "Elemwise"
  185. opdef = builtin.Elemwise
  186. def calc_flops(self):
  187. return np.prod(self.outputs[0].shape)
  188. class Reduce(OpNode):
  189. type = "Reduce"
  190. opdef = builtin.Reduce
  191. class TypeCvt(OpNode):
  192. type = "TypeCvt"
  193. opdef = builtin.TypeCvt
  194. @classmethod
  195. def load(cls, opr):
  196. obj = super(TypeCvt, cls).load(opr)
  197. t_dtype = opr.outputs[0].dtype
  198. obj.params["dtype"] = t_dtype
  199. return obj
  200. class MatrixInverse(OpNode):
  201. type = "MatrixInverse"
  202. opdef = builtin.MatrixInverse
  203. class MatrixMul(OpNode):
  204. type = "MatrixMul"
  205. opdef = builtin.MatrixMul
  206. def calc_flops(self):
  207. assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
  208. mid_shape = self.inputs[0].shape[1]
  209. return np.prod(self.outputs[0].shape) * mid_shape
  210. class BatchedMatrixMul(OpNode):
  211. type = "BatchedMatmul"
  212. opdef = builtin.BatchedMatrixMul
  213. def calc_flops(self):
  214. assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
  215. mid_shape = self.inputs[0].shape[2]
  216. return np.prod(self.outputs[0].shape) * mid_shape
  217. class Dot(OpNode):
  218. type = "Dot"
  219. opdef = builtin.Dot
  220. class SVD(OpNode):
  221. type = "SVD"
  222. opdef = builtin.SVD
  223. class ConvolutionForward(OpNode):
  224. type = "Convolution"
  225. opdef = builtin.Convolution
  226. def calc_flops(self):
  227. param_W_shape = self.inputs[1].shape
  228. kh = param_W_shape[-2]
  229. kw = param_W_shape[-1]
  230. if len(param_W_shape) == 5:
  231. num_input = param_W_shape[2]
  232. else:
  233. num_input = param_W_shape[1]
  234. NCHW = np.prod(self.outputs[0].shape)
  235. # N x Cout x H x W x (Cin x Kw x Kh)
  236. return NCHW * (num_input * kw * kh)
  237. class ConvolutionBackwardData(OpNode):
  238. type = "ConvTranspose"
  239. opdef = builtin.ConvolutionBackwardData
  240. class DeformableConvForward(OpNode):
  241. type = "DeformableConv"
  242. opdef = builtin.DeformableConv
  243. class GroupLocalForward(OpNode):
  244. type = "GroupLocal"
  245. opdef = builtin.GroupLocal
  246. class PoolingForward(OpNode):
  247. type = "Pooling"
  248. opdef = builtin.Pooling
  249. class AdaptivePoolingForward(OpNode):
  250. type = "AdaptivePooling"
  251. opdef = builtin.AdaptivePooling
  252. class ROIPoolingForward(OpNode):
  253. type = "ROIPooling"
  254. opdef = builtin.ROIPooling
  255. class DeformablePSROIPoolingForward(OpNode):
  256. type = "DeformablePSROIPooling"
  257. opdef = builtin.DeformablePSROIPooling
  258. class ConvBiasForward(OpNode):
  259. type = "ConvBias"
  260. opdef = builtin.ConvBias
  261. @classmethod
  262. def load(cls, opr):
  263. obj = super(ConvBiasForward, cls).load(opr)
  264. obj.params["dtype"] = opr.outputs[0].dtype
  265. return obj
  266. def calc_flops(self):
  267. param_W_shape = self.inputs[1].shape
  268. kh = param_W_shape[-2]
  269. kw = param_W_shape[-1]
  270. if len(param_W_shape) == 5:
  271. num_input = param_W_shape[2]
  272. else:
  273. num_input = param_W_shape[1]
  274. NCHW = np.prod(self.outputs[0].shape)
  275. # N x Cout x H x W x (Cin x Kw x Kh + bias)
  276. return NCHW * (num_input * kw * kh + 1)
  277. class BatchConvBiasForward(OpNode):
  278. type = "BatchConvBias"
  279. opdef = builtin.BatchConvBias
  280. @classmethod
  281. def load(cls, opr):
  282. obj = super(BatchConvBiasForward, cls).load(opr)
  283. obj.params["dtype"] = opr.outputs[0].dtype
  284. return obj
  285. class BatchNormForward(OpNode):
  286. type = "BatchNorm"
  287. opdef = builtin.BatchNorm
  288. output_idx = -1
  289. class ROIAlignForward(OpNode):
  290. type = "ROIAlign"
  291. opdef = builtin.ROIAlign
  292. class WarpPerspectiveForward(OpNode):
  293. type = "WarpPerspective"
  294. opdef = builtin.WarpPerspective
  295. class WarpAffineForward(OpNode):
  296. type = "WarpAffine"
  297. opdef = builtin.WarpAffine
  298. class RemapForward(OpNode):
  299. type = "Remap"
  300. opdef = builtin.Remap
  301. class ResizeForward(OpNode):
  302. type = "Resize"
  303. opdef = builtin.Resize
  304. class IndexingOneHot(OpNode):
  305. type = "IndexingOneHot"
  306. opdef = builtin.IndexingOneHot
  307. class IndexingSetOneHot(OpNode):
  308. type = "IndexingSetOneHot"
  309. opdef = builtin.IndexingSetOneHot
  310. class Copy(OpNode):
  311. type = "Copy"
  312. opdef = builtin.Copy
  313. @classmethod
  314. def load(cls, opr):
  315. obj = super(Copy, cls).load(opr)
  316. obj.params["comp_node"] = opr.outputs[0].comp_node
  317. return obj
  318. class ArgsortForward(OpNode):
  319. type = "Argsort"
  320. opdef = builtin.Argsort
  321. class Argmax(OpNode):
  322. type = "Argmax"
  323. opdef = builtin.Argmax
  324. class Argmin(OpNode):
  325. type = "Argmin"
  326. opdef = builtin.Argmin
  327. class CondTake(OpNode):
  328. type = "CondTake"
  329. opdef = builtin.CondTake
  330. class TopK(OpNode):
  331. type = "TopK"
  332. opdef = builtin.TopK
  333. class NvOf(OpNode):
  334. type = "NvOf"
  335. opdef = builtin.NvOf
  336. class RNGOpr(OpNode):
  337. @classmethod
  338. def load(cls, opr):
  339. obj = super(RNGOpr, cls).load(opr)
  340. if len(obj.params) == 3:
  341. obj.opdef = builtin.GaussianRNG
  342. obj.type = "GaussianRNG"
  343. else:
  344. obj.opdef = builtin.UniformRNG
  345. obj.type = "UniformRNG"
  346. return obj
  347. class Linspace(OpNode):
  348. type = "Linspace"
  349. opdef = builtin.Linspace
  350. @classmethod
  351. def load(cls, opr):
  352. obj = super(Linspace, cls).load(opr)
  353. obj.params["comp_node"] = opr.outputs[0].comp_node
  354. return obj
  355. class Eye(OpNode):
  356. type = "Eye"
  357. opdef = builtin.Eye
  358. @classmethod
  359. def load(cls, opr):
  360. obj = super(Eye, cls).load(opr)
  361. obj.params["dtype"] = opr.outputs[0].dtype
  362. obj.params["comp_node"] = opr.outputs[0].comp_node
  363. return obj
  364. class GetVarShape(OpNode):
  365. type = "GetVarShape"
  366. opdef = builtin.GetVarShape
  367. class Concat(OpNode):
  368. type = "Concat"
  369. opdef = builtin.Concat
  370. @classmethod
  371. def load(cls, opr):
  372. obj = super(Concat, cls).load(opr)
  373. obj.params["comp_node"] = Device("xpux").to_c()
  374. return obj
  375. class Broadcast(OpNode):
  376. type = "Broadcast"
  377. opdef = builtin.Broadcast
  378. class Identity(OpNode):
  379. type = "Identity"
  380. opdef = builtin.Identity
  381. class NMSKeep(OpNode):
  382. type = "NMSKeep"
  383. opdef = builtin.NMSKeep
  384. # class ParamPackSplit
  385. # class ParamPackConcat
  386. class Dimshuffle(OpNode):
  387. type = "Dimshuffle"
  388. opdef = builtin.Dimshuffle
  389. @classmethod
  390. def load(cls, opr):
  391. obj = super(Dimshuffle, cls).load(opr)
  392. del obj.params["ndim"]
  393. return obj
  394. class Reshape(OpNode):
  395. type = "Reshape"
  396. opdef = builtin.Reshape
  397. class AxisAddRemove(OpNode):
  398. type = "AxisAddRemove"
  399. @classmethod
  400. def load(cls, opr):
  401. obj = cls()
  402. obj.name = opr.name
  403. obj._opr = opr
  404. params = json.loads(opr.params)
  405. desc = params["desc"]
  406. method = None
  407. axis = []
  408. for i in desc:
  409. if method is None:
  410. method = i["method"]
  411. assert method == i["method"]
  412. axis.append(i["axisnum"])
  413. obj.params = {"axis": axis}
  414. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  415. return obj
  416. class IndexingBase(OpNode):
  417. @classmethod
  418. def load(cls, opr):
  419. obj = cls()
  420. obj.name = opr.name
  421. obj._opr = opr
  422. params = json.loads(opr.params)
  423. items = [
  424. [
  425. p["axis"],
  426. bool(p["begin"]),
  427. bool(p["end"]),
  428. bool(p["step"]),
  429. bool(p["idx"]),
  430. ]
  431. for p in params
  432. ]
  433. obj.params["items"] = items
  434. return obj
  435. class Subtensor(IndexingBase):
  436. type = "Subtensor"
  437. opdef = builtin.Subtensor
  438. class SetSubtensor(IndexingBase):
  439. type = "SetSubtensor"
  440. opdef = builtin.SetSubtensor
  441. class IncrSubtensor(IndexingBase):
  442. type = "IncrSubtensor"
  443. opdef = builtin.IncrSubtensor
  444. class IndexingMultiAxisVec(IndexingBase):
  445. type = "IndexingMultiAxisVec"
  446. opdef = builtin.IndexingMultiAxisVec
  447. class IndexingSetMultiAxisVec(IndexingBase):
  448. type = "IndexingSetMultiAxisVec"
  449. opdef = builtin.IndexingSetMultiAxisVec
  450. class IndexingIncrMultiAxisVec(IndexingBase):
  451. type = "IndexingIncrMultiAxisVec"
  452. opdef = builtin.IndexingIncrMultiAxisVec
  453. class MeshIndexing(IndexingBase):
  454. type = "MeshIndexing"
  455. opdef = builtin.MeshIndexing
  456. class SetMeshIndexing(IndexingBase):
  457. type = "SetMeshIndexing"
  458. opdef = builtin.SetMeshIndexing
  459. class IncrMeshIndexing(IndexingBase):
  460. type = "IncrMeshIndexing"
  461. opdef = builtin.IncrMeshIndexing
  462. class BatchedMeshIndexing(IndexingBase):
  463. type = "BatchedMeshIndexing"
  464. opdef = builtin.BatchedMeshIndexing
  465. class BatchedSetMeshIndexing(IndexingBase):
  466. type = "BatchedSetMeshIndexing"
  467. opdef = builtin.BatchedSetMeshIndexing
  468. class BatchedIncrMeshIndexing(IndexingBase):
  469. type = "BatchedIncrMeshIndexing"
  470. opdef = builtin.BatchedIncrMeshIndexing
  471. # class CollectiveComm
  472. # class RemoteSend
  473. # class RemoteRecv
  474. # class TQT
  475. # class FakeQuant
  476. # class InplaceAdd
  477. class AssertEqual(OpNode):
  478. type = "AssertEqual"
  479. opdef = builtin.AssertEqual
  480. class ElemwiseMultiType(OpNode):
  481. type = "ElemwiseMultiType"
  482. opdef = builtin.ElemwiseMultiType
  483. @classmethod
  484. def load(cls, opr):
  485. obj = super(ElemwiseMultiType, cls).load(opr)
  486. obj.params["dtype"] = opr.outputs[0].dtype
  487. return obj
  488. def calc_flops(self):
  489. return np.prod(self.outputs[0].shape)
  490. class CvtColorForward(OpNode):
  491. type = "CvtColor"
  492. opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台