You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 19 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Sequence
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._imperative_rt.core2 import SymbolVar, apply
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core._wrap import Device
  17. from ..core.ops import builtin
  18. from ..core.tensor.array_method import ArrayMethodMixin
  19. from .comp_graph_tools import replace_vars
  20. from .module_stats import (
  21. preprocess_receptive_field,
  22. register_flops,
  23. register_receptive_field,
  24. )
  25. class NetworkNode:
  26. pass
  27. class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
  28. pass
  29. class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
  30. def __init__(self, var=None, *, owner_opr=None, name=None):
  31. SymbolVar.__init__(self, var)
  32. self.owner = owner_opr
  33. self.name = name
  34. self.id = id(self)
  35. @classmethod
  36. def load(cls, sym_var, owner_opr):
  37. obj = cls()
  38. obj.var = sym_var # mgb varnode
  39. obj.name = sym_var.name
  40. obj.owner = owner_opr
  41. return obj
  42. def _get_var_shape(self, axis=None):
  43. opdef = (
  44. builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
  45. )
  46. return apply(opdef, self)[0]
  47. @property
  48. def partial_shape(self):
  49. """Return the tuple type inferred shape of VarNode
  50. """
  51. return tuple(self._get_var_shape().numpy())
  52. def shapeof(self, axis):
  53. """Return the symbolic shape of axis
  54. """
  55. return self._get_var_shape(axis=axis) if self.var else None
  56. @property
  57. def _tuple_shape(self):
  58. return self.partial_shape
  59. @property
  60. def shape(self):
  61. """Return the symbolic shape if using set_symbolic_shape(True)
  62. else inferred shape
  63. """
  64. rst = None
  65. if self.var:
  66. try:
  67. rst = self.var.shape
  68. except:
  69. rst = None
  70. if not use_symbolic_shape():
  71. return rst
  72. return self._get_var_shape() if self.var else None
  73. @property
  74. def dtype(self):
  75. return self.var.dtype if self.var else None
  76. @property
  77. def ndim(self):
  78. return super().ndim
  79. def __bool__(self):
  80. return False
  81. __index__ = None
  82. __int__ = None
  83. __float__ = None
  84. __complex__ = None
  85. def __hash__(self):
  86. return id(self)
  87. def numpy(self):
  88. return super().numpy()
  89. def _reset(self, other):
  90. if not isinstance(other, VarNode):
  91. assert self.graph, "VarNode _reset must have graph"
  92. node = ImmutableTensor(other, graph=self.graph)
  93. node.compile(self.graph)
  94. other = node.outputs[0]
  95. if self.owner is not None:
  96. idx = self.owner.outputs.index(self)
  97. self.owner.outputs[idx] = VarNode(
  98. self.var, owner_opr=self.owner, name=self.var.name
  99. )
  100. self.var = other.var
  101. self.owner = None
  102. def set_owner_opr(self, owner_opr):
  103. self.owner = owner_opr
  104. class OpNode(NetworkNode):
  105. opdef = None
  106. type = None
  107. def __init__(self):
  108. self.inputs = []
  109. self.outputs = []
  110. self.params = {}
  111. self._opr = None # mgb opnode
  112. @property
  113. def id(self):
  114. return id(self)
  115. @property
  116. def priority(self):
  117. if self._opr is not None:
  118. return (self._opr.priority, self._opr.id)
  119. return (0, 0)
  120. @classmethod
  121. def load(cls, opr):
  122. obj = cls()
  123. obj.params = json.loads(opr.params)
  124. obj.name = opr.name
  125. obj._opr = opr
  126. return obj
  127. def compile(self):
  128. if (
  129. self._opr is None
  130. or len(self._opr.inputs) != len(self.inputs)
  131. or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
  132. ):
  133. op = self.opdef(**self.params)
  134. args = [i.var for i in self.inputs]
  135. outputs = rt.invoke_op(op, args)
  136. assert len(outputs) == len(self.outputs)
  137. self._opr = outputs[0].owner
  138. for i in range(len(self.outputs)):
  139. self.outputs[i].var = outputs[i]
  140. self.outputs[i].var.name = self.outputs[i].name
  141. assert self.outputs[i].owner is self
  142. def add_inp_var(self, x):
  143. self.inputs.append(x)
  144. def add_out_var(self, x):
  145. self.outputs.append(x)
  146. def __repr__(self):
  147. return "%s{%s}" % (self.name, self.type)
  148. def str_to_mge_class(classname):
  149. # TODO: use megbrain C++ RTTI to replace type string
  150. if classname == "RNGOpr<MegDNNOpr>":
  151. classname = "RNGOpr"
  152. oprcls = getattr(sys.modules[__name__], classname, None)
  153. return oprcls if oprcls else ReadOnlyOpNode
  154. class Host2DeviceCopy(OpNode):
  155. type = "Host2DeviceCopy"
  156. def __init__(self, shape=None, dtype=None, name=None, device=None):
  157. super().__init__()
  158. self.shape = shape
  159. self.dtype = dtype
  160. self.name = name
  161. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  162. self.outputs = []
  163. @classmethod
  164. def load(cls, opr):
  165. self = cls()
  166. self.outputs = []
  167. assert len(opr.outputs) == 1, "wrong number of outputs"
  168. self.shape = opr.outputs[0].shape
  169. self.dtype = opr.outputs[0].dtype
  170. self.name = opr.outputs[0].name
  171. self.device = opr.outputs[0].comp_node
  172. self._opr = opr
  173. return self
  174. def compile(self, graph):
  175. if (
  176. self._opr is None
  177. or self._opr.outputs[0].comp_node != self.device
  178. or self._opr.outputs[0].shape != self.shape
  179. or self._opr.outputs[0].dtype != self.dtype
  180. ):
  181. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  182. self._opr = outputs.owner
  183. if len(self.outputs) == 0:
  184. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  185. self.outputs[0].var = outputs
  186. assert self.outputs[0].owner is self
class ImmutableTensor(OpNode):
    """Op node holding a constant value created through ``rt.make_const``."""

    type = "ImmutableTensor"

    def __init__(self, data=None, name=None, device=None, graph=None):
        super().__init__()
        self.name = name
        self.outputs = []
        self.graph = graph
        # set_value() asserts that self.graph is not None, so passing data
        # without a graph fails here.
        if data is not None:
            self.set_value(data, device)

    @property
    def device(self):
        """Comp node of the wrapped operator's first output, or ``None``."""
        return self._opr.outputs[0].comp_node if self._opr else None

    @device.setter
    def device(self, device):
        # Changing the device re-creates the constant from its current value.
        self.set_value(self.numpy(), device)

    @property
    def shape(self):
        return self.outputs[0].shape

    @property
    def dtype(self):
        return self._opr.outputs[0].dtype if self._opr else None

    def numpy(self):
        """Return the ``value`` of the wrapped operator's first output, or ``None``."""
        return self._opr.outputs[0].value if self._opr else None

    def set_value(self, data, device=None):
        """Create a new constant from ``data`` and bind it to the single output.

        ``device`` falls back to the current ``self.device`` when falsy.
        Input that is not already an ndarray is converted with ``np.array``.
        """
        assert self.graph is not None
        cn = device if device else self.device
        assert isinstance(data, (int, float, Sequence, np.ndarray))
        if not isinstance(data, np.ndarray):
            data = np.array(data)
            # NOTE(review): the flattened source does not show whether this
            # float64/int64 narrowing is nested under the ndarray check or
            # applies to every input; nesting is assumed here - confirm.
            if data.dtype == np.float64:
                data = data.astype(np.float32)
            elif data.dtype == np.int64:
                data = data.astype(np.int32)
        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
        # The output VarNode is created lazily on first use and reused after.
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(owner_opr=self, name=self.name))
        self.outputs[0].var = varnode
        self._opr = varnode.owner

    @classmethod
    def load(cls, opr):
        """Create a node from ``opr``, taking its graph and first output's name."""
        self = cls()
        self.outputs = []
        self._opr = opr
        self.name = opr.outputs[0].name
        self.graph = opr.graph
        return self

    def compile(self, graph):
        """Move the constant onto ``graph`` if it differs, then apply the name."""
        assert self.outputs[0].var is self._opr.outputs[0]
        assert self.outputs[0].owner is self
        if self.graph != graph:
            self.graph = graph
            # Re-create the constant on the new graph from its current value.
            self.set_value(self.numpy())
        if self.name is not None:
            self.outputs[0].var.name = self.name
  241. class ReadOnlyOpNode(OpNode):
  242. @classmethod
  243. def load(cls, opr):
  244. obj = super(ReadOnlyOpNode, cls).load(opr)
  245. obj.type = opr.type
  246. return obj
  247. def compile(self):
  248. assert self._opr is not None
  249. assert len(self.inputs) == len(self._opr.inputs)
  250. assert len(self.outputs) == len(self._opr.outputs)
  251. repl_dict = {}
  252. for ind, i in enumerate(self.inputs):
  253. if i.var != self._opr.inputs[ind]:
  254. repl_dict[self._opr.inputs[ind]] = i.var
  255. if bool(repl_dict):
  256. out_vars = replace_vars(self._opr.outputs, repl_dict)
  257. for ind, o in enumerate(self.outputs):
  258. o.var = out_vars[ind]
  259. class Elemwise(OpNode):
  260. type = "Elemwise"
  261. opdef = builtin.Elemwise
  262. def __repr__(self):
  263. return "%s{Elemwise:%s}" % (self.name, self.params["mode"])
  264. class ElemwiseMultiType(OpNode):
  265. type = "ElemwiseMultiType"
  266. opdef = builtin.ElemwiseMultiType
  267. def __repr__(self):
  268. return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])
  269. @classmethod
  270. def load(cls, opr):
  271. obj = super(ElemwiseMultiType, cls).load(opr)
  272. obj.params["dtype"] = opr.outputs[0].dtype
  273. return obj
  274. @register_flops(Elemwise, ElemwiseMultiType)
  275. def flops_elemwise(opnode: Elemwise, inputs, outputs):
  276. return np.prod(outputs[0].shape)
  277. class Reduce(OpNode):
  278. type = "Reduce"
  279. opdef = builtin.Reduce
  280. class TypeCvt(OpNode):
  281. type = "TypeCvt"
  282. opdef = builtin.TypeCvt
  283. @classmethod
  284. def load(cls, opr):
  285. obj = super(TypeCvt, cls).load(opr)
  286. t_dtype = opr.outputs[0].dtype
  287. obj.params["dtype"] = t_dtype
  288. return obj
  289. class MatrixInverse(OpNode):
  290. type = "MatrixInverse"
  291. opdef = builtin.MatrixInverse
  292. class MatrixMul(OpNode):
  293. type = "MatrixMul"
  294. opdef = builtin.MatrixMul
  295. @register_flops(MatrixMul)
  296. def flops_matmul(opnode: MatrixMul, inputs, outputs):
  297. assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
  298. mid_shape = inputs[0].shape[1]
  299. return np.prod(outputs[0].shape) * mid_shape
  300. class BatchedMatrixMul(OpNode):
  301. type = "BatchedMatmul"
  302. opdef = builtin.BatchedMatrixMul
  303. @register_flops(BatchedMatrixMul)
  304. def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
  305. assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
  306. mid_shape = inputs[0].shape[2]
  307. return np.prod(outputs[0].shape) * mid_shape
  308. class Dot(OpNode):
  309. type = "Dot"
  310. opdef = builtin.Dot
  311. class SVD(OpNode):
  312. type = "SVD"
  313. opdef = builtin.SVD
  314. class ConvolutionForward(OpNode):
  315. type = "Convolution"
  316. opdef = builtin.Convolution
  317. class ConvolutionBackwardData(OpNode):
  318. type = "ConvTranspose"
  319. opdef = builtin.ConvolutionBackwardData
  320. class DeformableConvForward(OpNode):
  321. type = "DeformableConv"
  322. opdef = builtin.DeformableConv
  323. class GroupLocalForward(OpNode):
  324. type = "GroupLocal"
  325. opdef = builtin.GroupLocal
  326. class PoolingForward(OpNode):
  327. type = "Pooling"
  328. opdef = builtin.Pooling
  329. class AdaptivePoolingForward(OpNode):
  330. type = "AdaptivePooling"
  331. opdef = builtin.AdaptivePooling
  332. class ROIPoolingForward(OpNode):
  333. type = "ROIPooling"
  334. opdef = builtin.ROIPooling
  335. class DeformablePSROIPoolingForward(OpNode):
  336. type = "DeformablePSROIPooling"
  337. opdef = builtin.DeformablePSROIPooling
  338. class ConvBiasForward(OpNode):
  339. type = "ConvBias"
  340. opdef = builtin.ConvBias
  341. @classmethod
  342. def load(cls, opr):
  343. obj = super(ConvBiasForward, cls).load(opr)
  344. obj.params["dtype"] = opr.outputs[0].dtype
  345. return obj
  346. @register_flops(
  347. ConvolutionForward, ConvBiasForward,
  348. )
  349. def flops_conv(opnode: ConvolutionForward, inputs, outputs):
  350. param_W_shape = inputs[1].shape
  351. kh = param_W_shape[-2]
  352. kw = param_W_shape[-1]
  353. if len(param_W_shape) == 5:
  354. num_input = param_W_shape[2]
  355. else:
  356. num_input = param_W_shape[1]
  357. NCHW = np.prod(outputs[0].shape)
  358. bias = 1 if isinstance(opnode, ConvBiasForward) else 0
  359. # N x Cout x H x W x (Cin x Kw x Kh)
  360. return NCHW * (num_input * kw * kh + bias)
  361. @register_receptive_field(ConvolutionForward, ConvBiasForward)
  362. def receptive_field(opnode: ConvolutionForward, inputs, outputs):
  363. pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
  364. param_W_shape = inputs[1].shape
  365. kh = param_W_shape[-2]
  366. kw = param_W_shape[-1]
  367. rf = (
  368. kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
  369. kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
  370. )
  371. stride = (
  372. opnode.params["stride_h"] * pre_stride[0],
  373. opnode.params["stride_w"] * pre_stride[1],
  374. )
  375. opnode._rf = rf
  376. opnode._stride = stride
  377. return rf, stride
  378. class BatchConvBiasForward(OpNode):
  379. type = "BatchConvBias"
  380. opdef = builtin.BatchConvBias
  381. @classmethod
  382. def load(cls, opr):
  383. obj = super(BatchConvBiasForward, cls).load(opr)
  384. obj.params["dtype"] = opr.outputs[0].dtype
  385. return obj
  386. class BatchNormForward(OpNode):
  387. type = "BatchNorm"
  388. opdef = builtin.BatchNorm
  389. output_idx = -1
  390. class ROIAlignForward(OpNode):
  391. type = "ROIAlign"
  392. opdef = builtin.ROIAlign
  393. class WarpPerspectiveForward(OpNode):
  394. type = "WarpPerspective"
  395. opdef = builtin.WarpPerspective
  396. class WarpAffineForward(OpNode):
  397. type = "WarpAffine"
  398. opdef = builtin.WarpAffine
  399. class RemapForward(OpNode):
  400. type = "Remap"
  401. opdef = builtin.Remap
  402. class ResizeForward(OpNode):
  403. type = "Resize"
  404. opdef = builtin.Resize
  405. class IndexingOneHot(OpNode):
  406. type = "IndexingOneHot"
  407. opdef = builtin.IndexingOneHot
  408. class IndexingSetOneHot(OpNode):
  409. type = "IndexingSetOneHot"
  410. opdef = builtin.IndexingSetOneHot
  411. class Copy(OpNode):
  412. type = "Copy"
  413. opdef = builtin.Copy
  414. @classmethod
  415. def load(cls, opr):
  416. obj = super(Copy, cls).load(opr)
  417. obj.params["comp_node"] = opr.outputs[0].comp_node
  418. return obj
  419. class ArgsortForward(OpNode):
  420. type = "Argsort"
  421. opdef = builtin.Argsort
  422. class Argmax(OpNode):
  423. type = "Argmax"
  424. opdef = builtin.Argmax
  425. class Argmin(OpNode):
  426. type = "Argmin"
  427. opdef = builtin.Argmin
  428. class CondTake(OpNode):
  429. type = "CondTake"
  430. opdef = builtin.CondTake
  431. class TopK(OpNode):
  432. type = "TopK"
  433. opdef = builtin.TopK
  434. class NvOf(OpNode):
  435. type = "NvOf"
  436. opdef = builtin.NvOf
  437. class RNGOpr(OpNode):
  438. @classmethod
  439. def load(cls, opr):
  440. obj = super(RNGOpr, cls).load(opr)
  441. if len(obj.params) == 3:
  442. obj.opdef = builtin.GaussianRNG
  443. obj.type = "GaussianRNG"
  444. else:
  445. obj.opdef = builtin.UniformRNG
  446. obj.type = "UniformRNG"
  447. return obj
  448. class Linspace(OpNode):
  449. type = "Linspace"
  450. opdef = builtin.Linspace
  451. @classmethod
  452. def load(cls, opr):
  453. obj = super(Linspace, cls).load(opr)
  454. obj.params["comp_node"] = opr.outputs[0].comp_node
  455. return obj
  456. class Eye(OpNode):
  457. type = "Eye"
  458. opdef = builtin.Eye
  459. @classmethod
  460. def load(cls, opr):
  461. obj = super(Eye, cls).load(opr)
  462. obj.params["dtype"] = opr.outputs[0].dtype
  463. obj.params["comp_node"] = opr.outputs[0].comp_node
  464. return obj
  465. class GetVarShape(OpNode):
  466. type = "GetVarShape"
  467. opdef = builtin.GetVarShape
  468. class Concat(OpNode):
  469. type = "Concat"
  470. opdef = builtin.Concat
  471. @classmethod
  472. def load(cls, opr):
  473. obj = super(Concat, cls).load(opr)
  474. obj.params["comp_node"] = Device("xpux").to_c()
  475. return obj
  476. class Broadcast(OpNode):
  477. type = "Broadcast"
  478. opdef = builtin.Broadcast
  479. class Identity(OpNode):
  480. type = "Identity"
  481. opdef = builtin.Identity
  482. class NMSKeep(OpNode):
  483. type = "NMSKeep"
  484. opdef = builtin.NMSKeep
  485. # class ParamPackSplit
  486. # class ParamPackConcat
  487. class Dimshuffle(OpNode):
  488. type = "Dimshuffle"
  489. opdef = builtin.Dimshuffle
  490. @classmethod
  491. def load(cls, opr):
  492. obj = super(Dimshuffle, cls).load(opr)
  493. del obj.params["ndim"]
  494. return obj
  495. class Reshape(OpNode):
  496. type = "Reshape"
  497. opdef = builtin.Reshape
  498. class AxisAddRemove(OpNode):
  499. type = "AxisAddRemove"
  500. @classmethod
  501. def load(cls, opr):
  502. obj = cls()
  503. obj.name = opr.name
  504. obj._opr = opr
  505. params = json.loads(opr.params)
  506. desc = params["desc"]
  507. method = None
  508. axis = []
  509. for i in desc:
  510. if method is None:
  511. method = i["method"]
  512. assert method == i["method"]
  513. axis.append(i["axisnum"])
  514. obj.params = {"axis": axis}
  515. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  516. return obj
  517. class IndexingBase(OpNode):
  518. @classmethod
  519. def load(cls, opr):
  520. obj = cls()
  521. obj.name = opr.name
  522. obj._opr = opr
  523. params = json.loads(opr.params)
  524. items = [
  525. [
  526. p["axis"],
  527. bool(p["begin"]),
  528. bool(p["end"]),
  529. bool(p["step"]),
  530. bool(p["idx"]),
  531. ]
  532. for p in params
  533. ]
  534. obj.params["items"] = items
  535. return obj
  536. class Subtensor(IndexingBase):
  537. type = "Subtensor"
  538. opdef = builtin.Subtensor
  539. class SetSubtensor(IndexingBase):
  540. type = "SetSubtensor"
  541. opdef = builtin.SetSubtensor
  542. class IncrSubtensor(IndexingBase):
  543. type = "IncrSubtensor"
  544. opdef = builtin.IncrSubtensor
  545. class IndexingMultiAxisVec(IndexingBase):
  546. type = "IndexingMultiAxisVec"
  547. opdef = builtin.IndexingMultiAxisVec
  548. class IndexingSetMultiAxisVec(IndexingBase):
  549. type = "IndexingSetMultiAxisVec"
  550. opdef = builtin.IndexingSetMultiAxisVec
  551. class IndexingIncrMultiAxisVec(IndexingBase):
  552. type = "IndexingIncrMultiAxisVec"
  553. opdef = builtin.IndexingIncrMultiAxisVec
  554. class MeshIndexing(IndexingBase):
  555. type = "MeshIndexing"
  556. opdef = builtin.MeshIndexing
  557. class SetMeshIndexing(IndexingBase):
  558. type = "SetMeshIndexing"
  559. opdef = builtin.SetMeshIndexing
  560. class IncrMeshIndexing(IndexingBase):
  561. type = "IncrMeshIndexing"
  562. opdef = builtin.IncrMeshIndexing
  563. class BatchedMeshIndexing(IndexingBase):
  564. type = "BatchedMeshIndexing"
  565. opdef = builtin.BatchedMeshIndexing
  566. class BatchedSetMeshIndexing(IndexingBase):
  567. type = "BatchedSetMeshIndexing"
  568. opdef = builtin.BatchedSetMeshIndexing
  569. class BatchedIncrMeshIndexing(IndexingBase):
  570. type = "BatchedIncrMeshIndexing"
  571. opdef = builtin.BatchedIncrMeshIndexing
  572. # class CollectiveComm
  573. # class RemoteSend
  574. # class RemoteRecv
  575. # class TQT
  576. # class FakeQuant
  577. # class InplaceAdd
  578. class AssertEqual(OpNode):
  579. type = "AssertEqual"
  580. opdef = builtin.AssertEqual
  581. class CvtColorForward(OpNode):
  582. type = "CvtColor"
  583. opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台