You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

network_node.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import sys
from typing import Callable

import numpy as np

from ..core import _imperative_rt as rt
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
  18. class NetworkNode:
  19. pass
  20. class VarNode(NetworkNode):
  21. def __init__(self, owner_opr=None, name=None):
  22. self.var = None
  23. self.owner = owner_opr
  24. self.name = name
  25. self.id = id(self)
  26. @classmethod
  27. def load(cls, sym_var, owner_opr):
  28. obj = cls()
  29. obj.var = sym_var # mgb varnode
  30. obj.name = sym_var.name
  31. obj.owner = owner_opr
  32. return obj
  33. @property
  34. def shape(self):
  35. rst = None
  36. if self.var:
  37. try:
  38. rst = self.var.shape
  39. except:
  40. rst = None
  41. return rst
  42. @property
  43. def dtype(self):
  44. return self.var.dtype if self.var else None
  45. def set_owner_opr(self, owner_opr):
  46. self.owner_opr = owner_opr
  47. class OpNode(NetworkNode):
  48. opdef = None
  49. type = None
  50. def __init__(self):
  51. self.inputs = []
  52. self.outputs = []
  53. self.params = {}
  54. self._opr = None # mgb opnode
  55. self.id = id(self)
  56. @classmethod
  57. def load(cls, opr):
  58. obj = cls()
  59. obj.params = json.loads(opr.params)
  60. obj.name = opr.name
  61. obj._opr = opr
  62. return obj
  63. def compile(self, graph=None):
  64. op = self.opdef(**self.params)
  65. args = [i.var for i in self.inputs]
  66. outputs = rt.invoke_op(op, args)
  67. assert len(outputs) == len(self.outputs)
  68. self._opr = outputs[0].owner
  69. for i in range(len(self.outputs)):
  70. self.outputs[i].var = outputs[i]
  71. self.outputs[i].var.name = self.outputs[i].name
  72. assert self.outputs[i].owner is self
  73. def add_inp_var(self, x):
  74. self.inputs.append(x)
  75. def add_out_var(self, x):
  76. self.outputs.append(x)
  77. def str_to_mge_class(classname):
  78. # TODO: use megbrain C++ RTTI to replace type string
  79. if classname == "RNGOpr<MegDNNOpr>":
  80. classname = "RNGOpr"
  81. oprcls = getattr(sys.modules[__name__], classname, None)
  82. return oprcls if oprcls else ReadOnlyOpNode
  83. class Host2DeviceCopy(OpNode):
  84. type = "Host2DeviceCopy"
  85. def __init__(self, shape=None, dtype=None, name=None, device=None):
  86. super().__init__()
  87. self.shape = shape
  88. self.dtype = dtype
  89. self.name = name
  90. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  91. self.outputs = []
  92. @classmethod
  93. def load(cls, opr):
  94. self = cls()
  95. self.outputs = []
  96. assert len(opr.outputs) == 1, "wrong number of outputs"
  97. self.shape = opr.outputs[0].shape
  98. self.dtype = opr.outputs[0].dtype
  99. self.name = opr.outputs[0].name
  100. self.device = opr.outputs[0].comp_node
  101. self._opr = opr
  102. return self
  103. def compile(self, graph):
  104. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  105. self._opr = outputs.owner
  106. if len(self.outputs) == 0:
  107. self.outputs.append(VarNode(self, self.name))
  108. self.outputs[0].var = outputs
  109. assert self.outputs[0].owner is self
  110. class ImmutableTensor(OpNode):
  111. type = "ImmutableTensor"
  112. def __init__(self, data=None, name=None, device=None, graph=None):
  113. super().__init__()
  114. self.name = name
  115. self.outputs = []
  116. self.graph = graph
  117. if data is not None:
  118. self.set_value(data, device)
  119. @property
  120. def device(self):
  121. return self._opr.outputs[0].comp_node if self._opr else None
  122. @device.setter
  123. def device(self, device):
  124. self.set_value(self.numpy(), device)
  125. @property
  126. def shape(self):
  127. return self.outputs[0].shape
  128. @property
  129. def dtype(self):
  130. return self._opr.outputs[0].dtype if self._opr else None
  131. def numpy(self):
  132. return self._opr.outputs[0].value if self._opr else None
  133. def set_value(self, data, device=None):
  134. assert self.graph is not None
  135. cn = device if device else self.device
  136. assert isinstance(data, (int, float, np.ndarray))
  137. if isinstance(data, (int, float)):
  138. data = np.array(data)
  139. if data.dtype == np.float64:
  140. data = data.astype(np.float32)
  141. elif data.dtype == np.int64:
  142. data = data.astype(np.int32)
  143. varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
  144. if len(self.outputs) == 0:
  145. self.outputs.append(VarNode(self, self.name))
  146. self.outputs[0].var = varnode
  147. self._opr = varnode.owner
  148. @classmethod
  149. def load(cls, opr):
  150. self = cls()
  151. self.outputs = []
  152. self._opr = opr
  153. self.name = opr.outputs[0].name
  154. self.graph = opr.graph
  155. return self
  156. def compile(self, graph):
  157. assert self.outputs[0].var is self._opr.outputs[0]
  158. assert self.outputs[0].owner is self
  159. if self.graph != graph:
  160. self.graph = graph
  161. self.set_value(self.numpy())
  162. if self.name is not None:
  163. self.outputs[0].var.name = self.name
  164. class ReadOnlyOpNode(OpNode):
  165. @classmethod
  166. def load(cls, opr):
  167. obj = super(ReadOnlyOpNode, cls).load(opr)
  168. obj.type = opr.type
  169. return obj
  170. def compile(self):
  171. assert self._opr is not None
  172. assert len(self.inputs) == len(self._opr.inputs)
  173. assert len(self.outputs) == len(self._opr.outputs)
  174. repl_dict = {}
  175. for ind, i in enumerate(self.inputs):
  176. if i.var != self._opr.inputs[ind]:
  177. repl_dict[self._opr.inputs[ind]] = i.var
  178. if bool(repl_dict):
  179. out_vars = replace_vars(self._opr.outputs, repl_dict)
  180. for ind, o in enumerate(self.outputs):
  181. o.var = out_vars[ind]
  182. class Elemwise(OpNode):
  183. type = "Elemwise"
  184. opdef = builtin.Elemwise
  185. class Reduce(OpNode):
  186. type = "Reduce"
  187. opdef = builtin.Reduce
  188. class TypeCvt(OpNode):
  189. type = "TypeCvt"
  190. opdef = builtin.TypeCvt
  191. @classmethod
  192. def load(cls, opr):
  193. obj = super(TypeCvt, cls).load(opr)
  194. t_dtype = opr.outputs[0].dtype
  195. obj.params["dtype"] = t_dtype
  196. return obj
  197. class MatrixInverse(OpNode):
  198. type = "MatrixInverse"
  199. opdef = builtin.MatrixInverse
  200. class MatrixMul(OpNode):
  201. type = "MatrixMul"
  202. opdef = builtin.MatrixMul
  203. class BatchedMatrixMul(OpNode):
  204. type = "BatchedMatmul"
  205. opdef = builtin.BatchedMatrixMul
  206. class Dot(OpNode):
  207. type = "Dot"
  208. opdef = builtin.Dot
  209. class SVD(OpNode):
  210. type = "SVD"
  211. opdef = builtin.SVD
  212. class ConvolutionForward(OpNode):
  213. type = "Convolution"
  214. opdef = builtin.Convolution
  215. class ConvolutionBackwardData(OpNode):
  216. type = "ConvTranspose"
  217. opdef = builtin.ConvolutionBackwardData
  218. class DeformableConvForward(OpNode):
  219. type = "DeformableConv"
  220. opdef = builtin.DeformableConv
  221. class GroupLocalForward(OpNode):
  222. type = "GroupLocal"
  223. opdef = builtin.GroupLocal
  224. class PoolingForward(OpNode):
  225. type = "Pooling"
  226. opdef = builtin.Pooling
  227. class AdaptivePoolingForward(OpNode):
  228. type = "AdaptivePooling"
  229. opdef = builtin.AdaptivePooling
  230. class ROIPoolingForward(OpNode):
  231. type = "ROIPooling"
  232. opdef = builtin.ROIPooling
  233. class DeformablePSROIPoolingForward(OpNode):
  234. type = "DeformablePSROIPooling"
  235. opdef = builtin.DeformablePSROIPooling
  236. class ConvBiasForward(OpNode):
  237. type = "ConvBias"
  238. opdef = builtin.ConvBias
  239. @classmethod
  240. def load(cls, opr):
  241. obj = super(ConvBiasForward, cls).load(opr)
  242. obj.params["dtype"] = opr.outputs[0].dtype
  243. return obj
  244. class BatchConvBiasForward(OpNode):
  245. type = "BatchConvBias"
  246. opdef = builtin.BatchConvBias
  247. @classmethod
  248. def load(cls, opr):
  249. obj = super(BatchConvBiasForward, cls).load(opr)
  250. obj.params["dtype"] = opr.outputs[0].dtype
  251. return obj
  252. class BatchNormForward(OpNode):
  253. type = "BatchNorm"
  254. opdef = builtin.BatchNorm
  255. class ROIAlignForward(OpNode):
  256. type = "ROIAlign"
  257. opdef = builtin.ROIAlign
  258. class WarpPerspectiveForward(OpNode):
  259. type = "WarpPerspective"
  260. opdef = builtin.WarpPerspective
  261. class WarpAffineForward(OpNode):
  262. type = "WarpAffine"
  263. opdef = builtin.WarpAffine
  264. class RemapForward(OpNode):
  265. type = "Remap"
  266. opdef = builtin.Remap
  267. class ResizeForward(OpNode):
  268. type = "Resize"
  269. opdef = builtin.Resize
  270. class IndexingOneHot(OpNode):
  271. type = "IndexingOneHot"
  272. opdef = builtin.IndexingOneHot
  273. class IndexingSetOneHot(OpNode):
  274. type = "IndexingSetOneHot"
  275. opdef = builtin.IndexingSetOneHot
  276. class Copy(OpNode):
  277. type = "Copy"
  278. opdef = builtin.Copy
  279. @classmethod
  280. def load(cls, opr):
  281. obj = super(Copy, cls).load(opr)
  282. obj.params["comp_node"] = opr.outputs[0].comp_node
  283. return obj
  284. class ArgsortForward(OpNode):
  285. type = "Argsort"
  286. opdef = builtin.Argsort
  287. class Argmax(OpNode):
  288. type = "Argmax"
  289. opdef = builtin.Argmax
  290. class Argmin(OpNode):
  291. type = "Argmin"
  292. opdef = builtin.Argmin
  293. class CondTake(OpNode):
  294. type = "CondTake"
  295. opdef = builtin.CondTake
  296. class TopK(OpNode):
  297. type = "TopK"
  298. opdef = builtin.TopK
  299. class NvOf(OpNode):
  300. type = "NvOf"
  301. opdef = builtin.NvOf
  302. class RNGOpr(OpNode):
  303. @classmethod
  304. def load(cls, opr):
  305. obj = super(RNGOpr, cls).load(opr)
  306. if len(obj.params) == 3:
  307. obj.opdef = builtin.GaussianRNG
  308. obj.type = "GaussianRNG"
  309. else:
  310. obj.opdef = builtin.UniformRNG
  311. obj.type = "UniformRNG"
  312. return obj
  313. class Linspace(OpNode):
  314. type = "Linspace"
  315. opdef = builtin.Linspace
  316. @classmethod
  317. def load(cls, opr):
  318. obj = super(Linspace, cls).load(opr)
  319. obj.params["comp_node"] = opr.outputs[0].comp_node
  320. return obj
  321. class Eye(OpNode):
  322. type = "Eye"
  323. opdef = builtin.Eye
  324. @classmethod
  325. def load(cls, opr):
  326. obj = super(Eye, cls).load(opr)
  327. obj.params["dtype"] = opr.outputs[0].dtype
  328. obj.params["comp_node"] = opr.outputs[0].comp_node
  329. return obj
  330. class GetVarShape(OpNode):
  331. type = "GetVarShape"
  332. opdef = builtin.GetVarShape
  333. class Concat(OpNode):
  334. type = "Concat"
  335. opdef = builtin.Concat
  336. @classmethod
  337. def load(cls, opr):
  338. obj = super(Concat, cls).load(opr)
  339. obj.params["comp_node"] = Device("xpux").to_c()
  340. return obj
  341. class Broadcast(OpNode):
  342. type = "Broadcast"
  343. opdef = builtin.Broadcast
  344. class Identity(OpNode):
  345. type = "Identity"
  346. opdef = builtin.Identity
  347. class NMSKeep(OpNode):
  348. type = "NMSKeep"
  349. opdef = builtin.NMSKeep
  350. # class ParamPackSplit
  351. # class ParamPackConcat
  352. class Dimshuffle(OpNode):
  353. type = "Dimshuffle"
  354. opdef = builtin.Dimshuffle
  355. @classmethod
  356. def load(cls, opr):
  357. obj = super(Dimshuffle, cls).load(opr)
  358. del obj.params["ndim"]
  359. return obj
  360. class Reshape(OpNode):
  361. type = "Reshape"
  362. opdef = builtin.Reshape
  363. class AxisAddRemove(OpNode):
  364. type = "AxisAddRemove"
  365. @classmethod
  366. def load(cls, opr):
  367. obj = cls()
  368. obj.name = opr.name
  369. obj._opr = opr
  370. params = json.loads(opr.params)
  371. desc = params["desc"]
  372. method = None
  373. axis = []
  374. for i in desc:
  375. if method is None:
  376. method = i["method"]
  377. assert method == i["method"]
  378. axis.append(i["axisnum"])
  379. obj.params = {"axis": axis}
  380. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  381. return obj
  382. class IndexingBase(OpNode):
  383. @classmethod
  384. def load(cls, opr):
  385. obj = cls()
  386. obj.name = opr.name
  387. obj._opr = opr
  388. params = json.loads(opr.params)
  389. items = [
  390. [
  391. p["axis"],
  392. bool(p["begin"]),
  393. bool(p["end"]),
  394. bool(p["step"]),
  395. bool(p["idx"]),
  396. ]
  397. for p in params
  398. ]
  399. obj.params["items"] = items
  400. return obj
  401. class Subtensor(IndexingBase):
  402. type = "Subtensor"
  403. opdef = builtin.Subtensor
  404. class SetSubtensor(IndexingBase):
  405. type = "SetSubtensor"
  406. opdef = builtin.SetSubtensor
  407. class IncrSubtensor(IndexingBase):
  408. type = "IncrSubtensor"
  409. opdef = builtin.IncrSubtensor
  410. class IndexingMultiAxisVec(IndexingBase):
  411. type = "IndexingMultiAxisVec"
  412. opdef = builtin.IndexingMultiAxisVec
  413. class IndexingSetMultiAxisVec(IndexingBase):
  414. type = "IndexingSetMultiAxisVec"
  415. opdef = builtin.IndexingSetMultiAxisVec
  416. class IndexingIncrMultiAxisVec(IndexingBase):
  417. type = "IndexingIncrMultiAxisVec"
  418. opdef = builtin.IndexingIncrMultiAxisVec
  419. class MeshIndexing(IndexingBase):
  420. type = "MeshIndexing"
  421. opdef = builtin.MeshIndexing
  422. class SetMeshIndexing(IndexingBase):
  423. type = "SetMeshIndexing"
  424. opdef = builtin.SetMeshIndexing
  425. class IncrMeshIndexing(IndexingBase):
  426. type = "IncrMeshIndexing"
  427. opdef = builtin.IncrMeshIndexing
  428. class BatchedMeshIndexing(IndexingBase):
  429. type = "BatchedMeshIndexing"
  430. opdef = builtin.BatchedMeshIndexing
  431. class BatchedSetMeshIndexing(IndexingBase):
  432. type = "BatchedSetMeshIndexing"
  433. opdef = builtin.BatchedSetMeshIndexing
  434. class BatchedIncrMeshIndexing(IndexingBase):
  435. type = "BatchedIncrMeshIndexing"
  436. opdef = builtin.BatchedIncrMeshIndexing
  437. # class CollectiveComm
  438. # class RemoteSend
  439. # class RemoteRecv
  440. # class TQT
  441. # class FakeQuant
  442. # class InplaceAdd
  443. class AssertEqual(OpNode):
  444. type = "AssertEqual"
  445. opdef = builtin.AssertEqual
  446. class ElemwiseMultiType(OpNode):
  447. type = "ElemwiseMultiType"
  448. opdef = builtin.ElemwiseMultiType
  449. @classmethod
  450. def load(cls, opr):
  451. obj = super(ElemwiseMultiType, cls).load(opr)
  452. obj.params["dtype"] = opr.outputs[0].dtype
  453. return obj
  454. class CvtColorForward(OpNode):
  455. type = "CvtColor"
  456. opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台