You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

network_node.py 20 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Sequence
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._imperative_rt.core2 import SymbolVar, apply
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core._wrap import Device
  17. from ..core.ops import builtin
  18. from ..core.tensor.array_method import ArrayMethodMixin
  19. from .comp_graph_tools import replace_vars
  20. from .module_stats import (
  21. preprocess_receptive_field,
  22. register_flops,
  23. register_receptive_field,
  24. )
  25. class NetworkNode:
  26. pass
  27. class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
  28. pass
  29. class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
  30. def __init__(self, var=None, *, owner_opr=None, name=None):
  31. SymbolVar.__init__(self, var)
  32. self.users = [] # List[OpNode]
  33. self.owner = owner_opr
  34. self.name = name
  35. self.id = id(self)
  36. @classmethod
  37. def load(cls, sym_var, owner_opr):
  38. obj = cls()
  39. obj.var = sym_var # mgb varnode
  40. obj.name = sym_var.name
  41. obj.owner = owner_opr
  42. return obj
  43. def _get_var_shape(self, axis=None):
  44. opdef = (
  45. builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
  46. )
  47. return apply(opdef, self)[0]
  48. @property
  49. def partial_shape(self):
  50. """Return the tuple type inferred shape of VarNode
  51. """
  52. return tuple(self._get_var_shape().numpy())
  53. def shapeof(self, axis):
  54. """Return the symbolic shape of axis
  55. """
  56. return self._get_var_shape(axis=axis) if self.var else None
  57. @property
  58. def _tuple_shape(self):
  59. return self.partial_shape
  60. @property
  61. def shape(self):
  62. """Return the symbolic shape if using set_symbolic_shape(True)
  63. else inferred shape
  64. """
  65. rst = None
  66. if self.var:
  67. try:
  68. rst = self.var.shape
  69. except:
  70. rst = None
  71. if not use_symbolic_shape():
  72. return rst
  73. return self._get_var_shape() if self.var else None
  74. @property
  75. def dtype(self):
  76. return self.var.dtype if self.var else None
  77. @property
  78. def ndim(self):
  79. return super().ndim
  80. def __bool__(self):
  81. return False
  82. __index__ = None
  83. __int__ = None
  84. __float__ = None
  85. __complex__ = None
  86. def __hash__(self):
  87. return id(self)
  88. def numpy(self):
  89. return super().numpy()
  90. def _reset(self, other):
  91. if not isinstance(other, VarNode):
  92. assert self.graph, "VarNode _reset must have graph"
  93. node = ImmutableTensor(other, graph=self.graph)
  94. node.compile(self.graph)
  95. other = node.outputs[0]
  96. if self.owner is not None:
  97. idx = self.owner.outputs.index(self)
  98. self.owner.outputs[idx] = VarNode(
  99. self.var, owner_opr=self.owner, name=self.var.name
  100. )
  101. self.var = other.var
  102. self.owner = None
  103. def set_owner_opr(self, owner_opr):
  104. self.owner = owner_opr
  105. class OpNode(NetworkNode):
  106. opdef = None
  107. type = None
  108. def __init__(self):
  109. self.inputs = []
  110. self.outputs = []
  111. self.params = {}
  112. self._opr = None # mgb opnode
  113. @property
  114. def id(self):
  115. return id(self)
  116. @property
  117. def priority(self):
  118. if self._opr is not None:
  119. return (self._opr.priority, self._opr.id)
  120. return (0, 0)
  121. @classmethod
  122. def load(cls, opr):
  123. obj = cls()
  124. obj.params = json.loads(opr.params)
  125. obj.name = opr.name
  126. obj._opr = opr
  127. return obj
  128. def compile(self):
  129. if (
  130. self._opr is None
  131. or len(self._opr.inputs) != len(self.inputs)
  132. or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
  133. ):
  134. op = self.opdef(**self.params)
  135. args = [i.var for i in self.inputs]
  136. outputs = rt.invoke_op(op, args)
  137. assert len(outputs) == len(self.outputs)
  138. self._opr = outputs[0].owner
  139. for i in range(len(self.outputs)):
  140. self.outputs[i].var = outputs[i]
  141. self.outputs[i].var.name = self.outputs[i].name
  142. assert self.outputs[i].owner is self
  143. def add_inp_var(self, x):
  144. self.inputs.append(x)
  145. def add_out_var(self, x):
  146. self.outputs.append(x)
  147. def __repr__(self):
  148. return "%s{%s}" % (self.name, self.type)
  149. def str_to_mge_class(classname):
  150. # TODO: use megbrain C++ RTTI to replace type string
  151. if classname == "RNGOpr<MegDNNOpr>":
  152. classname = "RNGOpr"
  153. oprcls = getattr(sys.modules[__name__], classname, None)
  154. return oprcls if oprcls else ReadOnlyOpNode
  155. class Host2DeviceCopy(OpNode):
  156. type = "Host2DeviceCopy"
  157. def __init__(self, shape=None, dtype=None, name=None, device=None):
  158. super().__init__()
  159. self.shape = shape
  160. self.dtype = dtype
  161. self.name = name
  162. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  163. self.outputs = []
  164. @classmethod
  165. def load(cls, opr):
  166. self = cls()
  167. self.outputs = []
  168. assert len(opr.outputs) == 1, "wrong number of outputs"
  169. self.shape = opr.outputs[0].shape
  170. self.dtype = opr.outputs[0].dtype
  171. self.name = opr.outputs[0].name
  172. self.device = opr.outputs[0].comp_node
  173. self._opr = opr
  174. return self
  175. def compile(self, graph):
  176. if (
  177. self._opr is None
  178. or self._opr.graph != graph
  179. or self._opr.outputs[0].comp_node != self.device
  180. or self._opr.outputs[0].shape != self.shape
  181. or self._opr.outputs[0].dtype != self.dtype
  182. ):
  183. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  184. self._opr = outputs.owner
  185. if len(self.outputs) == 0:
  186. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  187. self.outputs[0].var = outputs
  188. assert self.outputs[0].owner is self
  189. class ConstOpBase(OpNode):
  190. type = "ConstOpBase"
  191. def __init__(self, data=None, name=None, device=None, graph=None):
  192. assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
  193. super().__init__()
  194. self.name = name
  195. self.outputs = []
  196. self.graph = graph
  197. if data is not None:
  198. self.set_value(data, device)
  199. @property
  200. def device(self):
  201. return self._opr.outputs[0].comp_node if self._opr else None
  202. @device.setter
  203. def device(self, device):
  204. self.set_value(self.numpy(), device)
  205. @property
  206. def shape(self):
  207. return self.outputs[0].shape
  208. @property
  209. def dtype(self):
  210. return self._opr.outputs[0].dtype if self._opr else None
  211. def numpy(self):
  212. return self.outputs[0].numpy()
  213. def set_value(self, data, device=None):
  214. assert self.graph is not None
  215. cn = device if device else self.device
  216. assert isinstance(data, (int, float, Sequence, np.ndarray))
  217. if not isinstance(data, np.ndarray):
  218. data = np.array(data)
  219. if data.dtype == np.float64:
  220. data = data.astype(np.float32)
  221. elif data.dtype == np.int64:
  222. data = data.astype(np.int32)
  223. varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
  224. if len(self.outputs) == 0:
  225. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  226. self.outputs[0].var = varnode
  227. self._opr = varnode.owner
  228. @classmethod
  229. def load(cls, opr):
  230. self = cls()
  231. self.outputs = []
  232. self._opr = opr
  233. self.name = opr.outputs[0].name
  234. self.graph = opr.graph
  235. return self
  236. def compile(self, graph):
  237. assert self.outputs[0].var is self._opr.outputs[0]
  238. assert self.outputs[0].owner is self
  239. if self.graph != graph:
  240. self.graph = graph
  241. self.set_value(self.numpy())
  242. if self.name is not None:
  243. self.outputs[0].var.name = self.name
  244. class ImmutableTensor(ConstOpBase):
  245. type = "ImmutableTensor"
  246. rt_fun = rt.make_const
  247. class SharedDeviceTensor(ConstOpBase):
  248. type = "SharedDeviceTensor"
  249. rt_fun = rt.make_shared
  250. class ReadOnlyOpNode(OpNode):
  251. @classmethod
  252. def load(cls, opr):
  253. obj = super(ReadOnlyOpNode, cls).load(opr)
  254. obj.type = opr.type
  255. return obj
  256. def compile(self):
  257. assert self._opr is not None
  258. assert len(self.inputs) == len(self._opr.inputs)
  259. assert len(self.outputs) == len(self._opr.outputs)
  260. repl_dict = {}
  261. for ind, i in enumerate(self.inputs):
  262. if i.var != self._opr.inputs[ind]:
  263. repl_dict[self._opr.inputs[ind]] = i.var
  264. if bool(repl_dict):
  265. out_vars = replace_vars(self._opr.outputs, repl_dict)
  266. for ind, o in enumerate(self.outputs):
  267. o.var = out_vars[ind]
  268. class Elemwise(OpNode):
  269. type = "Elemwise"
  270. opdef = builtin.Elemwise
  271. def __repr__(self):
  272. return "%s{Elemwise:%s}" % (self.name, self.params["mode"])
  273. class ElemwiseMultiType(OpNode):
  274. type = "ElemwiseMultiType"
  275. opdef = builtin.ElemwiseMultiType
  276. def __repr__(self):
  277. return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])
  278. @classmethod
  279. def load(cls, opr):
  280. obj = super(ElemwiseMultiType, cls).load(opr)
  281. obj.params["dtype"] = opr.outputs[0].dtype
  282. return obj
  283. @register_flops(Elemwise, ElemwiseMultiType)
  284. def flops_elemwise(opnode: Elemwise, inputs, outputs):
  285. return np.prod(outputs[0].shape)
  286. class Reduce(OpNode):
  287. type = "Reduce"
  288. opdef = builtin.Reduce
  289. class TypeCvt(OpNode):
  290. type = "TypeCvt"
  291. opdef = builtin.TypeCvt
  292. @classmethod
  293. def load(cls, opr):
  294. obj = super(TypeCvt, cls).load(opr)
  295. t_dtype = opr.outputs[0].dtype
  296. obj.params["dtype"] = t_dtype
  297. return obj
  298. class MatrixInverse(OpNode):
  299. type = "MatrixInverse"
  300. opdef = builtin.MatrixInverse
  301. class MatrixMul(OpNode):
  302. type = "MatrixMul"
  303. opdef = builtin.MatrixMul
  304. @register_flops(MatrixMul)
  305. def flops_matmul(opnode: MatrixMul, inputs, outputs):
  306. assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
  307. mid_shape = inputs[0].shape[1]
  308. return np.prod(outputs[0].shape) * mid_shape
  309. class BatchedMatrixMul(OpNode):
  310. type = "BatchedMatmul"
  311. opdef = builtin.BatchedMatrixMul
  312. @register_flops(BatchedMatrixMul)
  313. def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
  314. assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
  315. mid_shape = inputs[0].shape[2]
  316. return np.prod(outputs[0].shape) * mid_shape
  317. class Dot(OpNode):
  318. type = "Dot"
  319. opdef = builtin.Dot
  320. class SVD(OpNode):
  321. type = "SVD"
  322. opdef = builtin.SVD
  323. class ConvolutionForward(OpNode):
  324. type = "Convolution"
  325. opdef = builtin.Convolution
  326. class ConvolutionBackwardData(OpNode):
  327. type = "ConvTranspose"
  328. opdef = builtin.ConvolutionBackwardData
  329. class DeformableConvForward(OpNode):
  330. type = "DeformableConv"
  331. opdef = builtin.DeformableConv
  332. class GroupLocalForward(OpNode):
  333. type = "GroupLocal"
  334. opdef = builtin.GroupLocal
  335. class PoolingForward(OpNode):
  336. type = "Pooling"
  337. opdef = builtin.Pooling
  338. class AdaptivePoolingForward(OpNode):
  339. type = "AdaptivePooling"
  340. opdef = builtin.AdaptivePooling
  341. class ROIPoolingForward(OpNode):
  342. type = "ROIPooling"
  343. opdef = builtin.ROIPooling
  344. class DeformablePSROIPoolingForward(OpNode):
  345. type = "DeformablePSROIPooling"
  346. opdef = builtin.DeformablePSROIPooling
  347. class ConvBiasForward(OpNode):
  348. type = "ConvBias"
  349. opdef = builtin.ConvBias
  350. @classmethod
  351. def load(cls, opr):
  352. obj = super(ConvBiasForward, cls).load(opr)
  353. obj.params["dtype"] = opr.outputs[0].dtype
  354. return obj
  355. @register_flops(
  356. ConvolutionForward, ConvBiasForward,
  357. )
  358. def flops_conv(opnode: ConvolutionForward, inputs, outputs):
  359. param_W_shape = inputs[1].shape
  360. kh = param_W_shape[-2]
  361. kw = param_W_shape[-1]
  362. if len(param_W_shape) == 5:
  363. num_input = param_W_shape[2]
  364. else:
  365. num_input = param_W_shape[1]
  366. NCHW = np.prod(outputs[0].shape)
  367. bias = 1 if isinstance(opnode, ConvBiasForward) else 0
  368. # N x Cout x H x W x (Cin x Kw x Kh)
  369. return NCHW * (num_input * kw * kh + bias)
  370. @register_receptive_field(ConvolutionForward, ConvBiasForward)
  371. def receptive_field(opnode: ConvolutionForward, inputs, outputs):
  372. pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
  373. param_W_shape = inputs[1].shape
  374. kh = param_W_shape[-2]
  375. kw = param_W_shape[-1]
  376. rf = (
  377. kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
  378. kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
  379. )
  380. stride = (
  381. opnode.params["stride_h"] * pre_stride[0],
  382. opnode.params["stride_w"] * pre_stride[1],
  383. )
  384. opnode._rf = rf
  385. opnode._stride = stride
  386. return rf, stride
  387. class BatchConvBiasForward(OpNode):
  388. type = "BatchConvBias"
  389. opdef = builtin.BatchConvBias
  390. @classmethod
  391. def load(cls, opr):
  392. obj = super(BatchConvBiasForward, cls).load(opr)
  393. obj.params["dtype"] = opr.outputs[0].dtype
  394. return obj
  395. class BatchNormForward(OpNode):
  396. type = "BatchNorm"
  397. opdef = builtin.BatchNorm
  398. output_idx = -1
  399. class ROIAlignForward(OpNode):
  400. type = "ROIAlign"
  401. opdef = builtin.ROIAlign
  402. class WarpPerspectiveForward(OpNode):
  403. type = "WarpPerspective"
  404. opdef = builtin.WarpPerspective
  405. class WarpAffineForward(OpNode):
  406. type = "WarpAffine"
  407. opdef = builtin.WarpAffine
  408. class RemapForward(OpNode):
  409. type = "Remap"
  410. opdef = builtin.Remap
  411. class ResizeForward(OpNode):
  412. type = "Resize"
  413. opdef = builtin.Resize
  414. class IndexingOneHot(OpNode):
  415. type = "IndexingOneHot"
  416. opdef = builtin.IndexingOneHot
  417. class IndexingSetOneHot(OpNode):
  418. type = "IndexingSetOneHot"
  419. opdef = builtin.IndexingSetOneHot
  420. class Copy(OpNode):
  421. type = "Copy"
  422. opdef = builtin.Copy
  423. @classmethod
  424. def load(cls, opr):
  425. obj = super(Copy, cls).load(opr)
  426. obj.params["comp_node"] = opr.outputs[0].comp_node
  427. return obj
  428. class ArgsortForward(OpNode):
  429. type = "Argsort"
  430. opdef = builtin.Argsort
  431. class Argmax(OpNode):
  432. type = "Argmax"
  433. opdef = builtin.Argmax
  434. class Argmin(OpNode):
  435. type = "Argmin"
  436. opdef = builtin.Argmin
  437. class CondTake(OpNode):
  438. type = "CondTake"
  439. opdef = builtin.CondTake
  440. class TopK(OpNode):
  441. type = "TopK"
  442. opdef = builtin.TopK
  443. class NvOf(OpNode):
  444. type = "NvOf"
  445. opdef = builtin.NvOf
  446. class RNGOpr(OpNode):
  447. @classmethod
  448. def load(cls, opr):
  449. obj = super(RNGOpr, cls).load(opr)
  450. if len(obj.params) == 3:
  451. obj.opdef = builtin.GaussianRNG
  452. obj.type = "GaussianRNG"
  453. else:
  454. obj.opdef = builtin.UniformRNG
  455. obj.type = "UniformRNG"
  456. return obj
  457. class Linspace(OpNode):
  458. type = "Linspace"
  459. opdef = builtin.Linspace
  460. @classmethod
  461. def load(cls, opr):
  462. obj = super(Linspace, cls).load(opr)
  463. obj.params["comp_node"] = opr.outputs[0].comp_node
  464. return obj
  465. class Eye(OpNode):
  466. type = "Eye"
  467. opdef = builtin.Eye
  468. @classmethod
  469. def load(cls, opr):
  470. obj = super(Eye, cls).load(opr)
  471. obj.params["dtype"] = opr.outputs[0].dtype
  472. obj.params["comp_node"] = opr.outputs[0].comp_node
  473. return obj
  474. class GetVarShape(OpNode):
  475. type = "GetVarShape"
  476. opdef = builtin.GetVarShape
  477. class Concat(OpNode):
  478. type = "Concat"
  479. opdef = builtin.Concat
  480. @classmethod
  481. def load(cls, opr):
  482. obj = super(Concat, cls).load(opr)
  483. obj.params["comp_node"] = Device("xpux").to_c()
  484. return obj
  485. class Broadcast(OpNode):
  486. type = "Broadcast"
  487. opdef = builtin.Broadcast
  488. class Identity(OpNode):
  489. type = "Identity"
  490. opdef = builtin.Identity
  491. class NMSKeep(OpNode):
  492. type = "NMSKeep"
  493. opdef = builtin.NMSKeep
  494. # class ParamPackSplit
  495. # class ParamPackConcat
  496. class Dimshuffle(OpNode):
  497. type = "Dimshuffle"
  498. opdef = builtin.Dimshuffle
  499. @classmethod
  500. def load(cls, opr):
  501. obj = super(Dimshuffle, cls).load(opr)
  502. del obj.params["ndim"]
  503. return obj
  504. class Reshape(OpNode):
  505. type = "Reshape"
  506. opdef = builtin.Reshape
  507. class AxisAddRemove(OpNode):
  508. type = "AxisAddRemove"
  509. @classmethod
  510. def load(cls, opr):
  511. obj = cls()
  512. obj.name = opr.name
  513. obj._opr = opr
  514. params = json.loads(opr.params)
  515. desc = params["desc"]
  516. method = None
  517. axis = []
  518. for i in desc:
  519. if method is None:
  520. method = i["method"]
  521. assert method == i["method"]
  522. axis.append(i["axisnum"])
  523. obj.params = {"axis": axis}
  524. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  525. return obj
  526. class IndexingBase(OpNode):
  527. @classmethod
  528. def load(cls, opr):
  529. obj = cls()
  530. obj.name = opr.name
  531. obj._opr = opr
  532. params = json.loads(opr.params)
  533. items = [
  534. [
  535. p["axis"],
  536. bool(p["begin"]),
  537. bool(p["end"]),
  538. bool(p["step"]),
  539. bool(p["idx"]),
  540. ]
  541. for p in params
  542. ]
  543. obj.params["items"] = items
  544. return obj
  545. class Subtensor(IndexingBase):
  546. type = "Subtensor"
  547. opdef = builtin.Subtensor
  548. class SetSubtensor(IndexingBase):
  549. type = "SetSubtensor"
  550. opdef = builtin.SetSubtensor
  551. class IncrSubtensor(IndexingBase):
  552. type = "IncrSubtensor"
  553. opdef = builtin.IncrSubtensor
  554. class IndexingMultiAxisVec(IndexingBase):
  555. type = "IndexingMultiAxisVec"
  556. opdef = builtin.IndexingMultiAxisVec
  557. class IndexingSetMultiAxisVec(IndexingBase):
  558. type = "IndexingSetMultiAxisVec"
  559. opdef = builtin.IndexingSetMultiAxisVec
  560. class IndexingIncrMultiAxisVec(IndexingBase):
  561. type = "IndexingIncrMultiAxisVec"
  562. opdef = builtin.IndexingIncrMultiAxisVec
  563. class MeshIndexing(IndexingBase):
  564. type = "MeshIndexing"
  565. opdef = builtin.MeshIndexing
  566. class SetMeshIndexing(IndexingBase):
  567. type = "SetMeshIndexing"
  568. opdef = builtin.SetMeshIndexing
  569. class IncrMeshIndexing(IndexingBase):
  570. type = "IncrMeshIndexing"
  571. opdef = builtin.IncrMeshIndexing
  572. class BatchedMeshIndexing(IndexingBase):
  573. type = "BatchedMeshIndexing"
  574. opdef = builtin.BatchedMeshIndexing
  575. class BatchedSetMeshIndexing(IndexingBase):
  576. type = "BatchedSetMeshIndexing"
  577. opdef = builtin.BatchedSetMeshIndexing
  578. class BatchedIncrMeshIndexing(IndexingBase):
  579. type = "BatchedIncrMeshIndexing"
  580. opdef = builtin.BatchedIncrMeshIndexing
  581. # class CollectiveComm
  582. # class RemoteSend
  583. # class RemoteRecv
  584. # class TQT
  585. # class FakeQuant
  586. # class InplaceAdd
  587. class AssertEqual(OpNode):
  588. type = "AssertEqual"
  589. opdef = builtin.AssertEqual
  590. class CvtColorForward(OpNode):
  591. type = "CvtColor"
  592. opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台