You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 19 kB

(line-number gutter: lines 1–766)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import abc
  10. import json
  11. import sys
  12. from typing import Callable, Sequence
  13. import numpy as np
  14. from ..core import _imperative_rt as rt
  15. from ..core._imperative_rt.core2 import SymbolVar, apply
  16. from ..core._trace_option import use_symbolic_shape
  17. from ..core._wrap import Device
  18. from ..core.ops import builtin
  19. from ..core.tensor.array_method import ArrayMethodMixin
  20. from ..core.tensor.indexing import getitem as _getitem
  21. from ..core.tensor.indexing import setitem as _setitem
  22. from ..core.tensor.megbrain_graph import InputNode, OutputNode
  23. from ..tensor import Tensor
  24. from .comp_graph_tools import replace_vars
  25. from .module_stats import (
  26. preprocess_receptive_field,
  27. register_flops,
  28. register_receptive_field,
  29. )
  class NetworkNode:
      """Marker base class shared by ``VarNode`` and ``OpNode``.

      It defines no state or behavior of its own.
      """
  class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
      """Metaclass deriving from the metaclasses of ``SymbolVar`` and ``ArrayMethodMixin``.

      ``VarNode`` inherits from both of those bases. Python requires a class's
      metaclass to derive from every base's metaclass; otherwise class creation
      fails with a metaclass-conflict ``TypeError``.
      """
  class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
      """Python-side wrapper around a symbolic graph variable.

      Attributes:
          var: the wrapped runtime var. It is passed to ``SymbolVar.__init__``
              and reassigned by ``load`` and ``__setitem__``.
          owner: the ``OpNode`` producing this var, or ``None``.
          name: name of the var.
          id: ``id(self)``, captured at construction.
      """

      def __init__(self, var=None, *, owner_opr=None, name=None):
          SymbolVar.__init__(self, var)
          self.owner = owner_opr
          self.name = name
          self.id = id(self)

      @classmethod
      def load(cls, sym_var, owner_opr):
          """Wrap the existing runtime var ``sym_var``, produced by ``owner_opr``."""
          obj = cls()
          obj.var = sym_var  # mgb varnode
          obj.name = sym_var.name
          obj.owner = owner_opr
          return obj

      def _get_var_shape(self, axis=None):
          """Apply a ``GetVarShape`` op to this var, optionally for one ``axis``."""
          opdef = (
              builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
          )
          return apply(opdef, self)[0]

      @property
      def partial_shape(self):
          """Return the inferred shape of this var as a tuple."""
          return tuple(self._get_var_shape().numpy())

      def shapeof(self, axis):
          """Return the symbolic size of ``axis``, or ``None`` if no var is wrapped."""
          return self._get_var_shape(axis=axis) if self.var else None

      @property
      def _tuple_shape(self):
          return self.partial_shape

      @property
      def shape(self):
          """Return the symbolic shape if ``use_symbolic_shape()`` is enabled,
          otherwise the inferred shape.

          Returns ``None`` when no var is wrapped. In the non-symbolic case it
          also returns ``None`` when reading ``var.shape`` raises.
          """
          if use_symbolic_shape():
              return self._get_var_shape() if self.var else None
          if not self.var:
              return None
          try:
              return self.var.shape
          except Exception:
              # Ordinary errors mean "shape unknown". KeyboardInterrupt and
              # SystemExit are not swallowed (unlike with a bare ``except:``).
              return None

      @property
      def dtype(self):
          """Return the dtype of the wrapped var, or ``None`` if there is none."""
          return self.var.dtype if self.var else None

      def __bool__(self):
          # VarNode objects are always falsy.
          # NOTE(review): presumably so a symbolic var is never used as a Python
          # condition by accident -- confirm intent.
          return False

      # Setting these special methods to None marks implicit conversion to
      # Python numbers / indices as unsupported for this class.
      __index__ = None
      __int__ = None
      __float__ = None
      __complex__ = None

      def __hash__(self):
          # Identity-based hash.
          return id(self)

      def numpy(self):
          """Compile and execute ``self.graph`` to fetch this var's value.

          The value is returned through the output node's ``.numpy()``.
          """
          o = OutputNode(self.var)
          self.graph.compile(o.outputs).execute()
          return o.get_value().numpy()

      def __getitem__(self, index):
          return _getitem(self, index)

      def __setitem__(self, index, value):
          """Rebind this node to the var produced by assigning ``value`` at ``index``.

          When ``index`` is ``Ellipsis``, ``value.var`` itself becomes this
          node's var. If the node has an owner, the owner's output slot is
          replaced by a fresh ``VarNode`` that wraps the old var, and this node
          is detached (``owner`` becomes ``None``).
          """
          if index is not Ellipsis:
              value = _setitem(self, index, value)
          if self.owner is not None:
              idx = self.owner.outputs.index(self)
              self.owner.outputs[idx] = VarNode(
                  self.var, owner_opr=self.owner, name=self.var.name
              )
          self.var = value.var
          self.owner = None

      def set_owner_opr(self, owner_opr):
          """Set the ``OpNode`` that produces this var."""
          self.owner = owner_opr
  class OpNode(NetworkNode):
      """Python-side wrapper of a graph operator.

      Subclasses set ``type`` (the operator's type string) and ``opdef`` (the
      ``builtin`` op class that ``compile`` instantiates with ``params``).
      """

      opdef = None
      type = None

      def __init__(self):
          self.inputs = []
          self.outputs = []
          self.params = {}
          # Default name so ``__repr__`` works on nodes built directly rather
          # than through ``load``, which sets the real name.
          self.name = None
          self._opr = None  # mgb opnode
          self.id = id(self)

      @classmethod
      def load(cls, opr):
          """Wrap the existing runtime operator ``opr``.

          ``opr.params`` is decoded with ``json.loads`` into ``params``.
          ``compile`` later passes those params as keyword arguments to ``opdef``.
          """
          obj = cls()
          obj.params = json.loads(opr.params)
          obj.name = opr.name
          obj._opr = opr
          return obj

      def compile(self, graph=None):
          """Rebuild the operator as ``opdef(**params)`` over the current input vars.

          Each output node is rebound to the corresponding new var, and the
          node's ``name`` is applied to that var. ``graph`` is not used here;
          subclasses that need a graph take it in their own ``compile``.
          """
          op = self.opdef(**self.params)
          args = [i.var for i in self.inputs]
          outputs = rt.invoke_op(op, args)
          assert len(outputs) == len(self.outputs)
          self._opr = outputs[0].owner
          for out_node, out_var in zip(self.outputs, outputs):
              out_node.var = out_var
              out_node.var.name = out_node.name
              assert out_node.owner is self

      def add_inp_var(self, x):
          """Append ``x`` to this node's inputs."""
          self.inputs.append(x)

      def add_out_var(self, x):
          """Append ``x`` to this node's outputs."""
          self.outputs.append(x)

      def __repr__(self):
          return "%s{%s}" % (self.name, self.type)
  def str_to_mge_class(classname):
      """Map a runtime operator type string to the ``OpNode`` subclass that wraps it.

      Only ``OpNode`` subclasses reachable as attributes of this module are
      returned. Any other module-level name that collides with ``classname``
      (an imported helper such as ``Tensor`` or ``Device``, or a function) is
      ignored. Unknown types fall back to ``ReadOnlyOpNode``.
      """
      # TODO: use megbrain C++ RTTI to replace type string
      if classname == "RNGOpr<MegDNNOpr>":
          classname = "RNGOpr"
      oprcls = getattr(sys.modules[__name__], classname, None)
      if isinstance(oprcls, type) and issubclass(oprcls, OpNode):
          return oprcls
      return ReadOnlyOpNode
  class Host2DeviceCopy(OpNode):
      """Operator producing a single var via ``rt.make_h2d`` (host-to-device copy).

      Attributes:
          shape, dtype: shape and dtype passed to ``rt.make_h2d``.
          name: name of this node and of its output var.
          device: comp node of the output; defaults to ``Device("xpux")``.
      """

      type = "Host2DeviceCopy"

      def __init__(self, shape=None, dtype=None, name=None, device=None):
          super().__init__()  # also initializes ``outputs`` to []
          self.shape = shape
          self.dtype = dtype
          self.name = name
          self.device = Device(device).to_c() if device else Device("xpux").to_c()

      @classmethod
      def load(cls, opr):
          """Wrap an existing single-output h2d operator ``opr``."""
          obj = cls()
          assert len(opr.outputs) == 1, "wrong number of outputs"
          out_var = opr.outputs[0]
          obj.shape = out_var.shape
          obj.dtype = out_var.dtype
          obj.name = out_var.name
          obj.device = out_var.comp_node
          obj._opr = opr
          return obj

      def compile(self, graph):
          """Create the h2d var in ``graph`` and bind it to this node's output.

          An output ``VarNode`` is created on first compile if none exists yet.
          """
          h2d_var = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
          self._opr = h2d_var.owner
          if not self.outputs:
              self.outputs.append(VarNode(owner_opr=self, name=self.name))
          self.outputs[0].var = h2d_var
          assert self.outputs[0].owner is self
  class ImmutableTensor(OpNode):
      """Constant operator whose single output var holds fixed data.

      The constant is built in ``graph`` with ``rt.make_const``. ``compile``
      rebuilds it when asked to compile into a different graph.
      """

      type = "ImmutableTensor"

      def __init__(self, data=None, name=None, device=None, graph=None):
          super().__init__()  # also initializes ``outputs`` to []
          self.name = name
          self.graph = graph
          if data is not None:
              self.set_value(data, device)

      @property
      def device(self):
          """Comp node of the output var, or ``None`` before the constant exists."""
          return self._opr.outputs[0].comp_node if self._opr else None

      @device.setter
      def device(self, device):
          # Changing the device rebuilds the constant with its current value.
          self.set_value(self.numpy(), device)

      @property
      def shape(self):
          return self.outputs[0].shape

      @property
      def dtype(self):
          return self._opr.outputs[0].dtype if self._opr else None

      def numpy(self):
          """Return the constant's value, or ``None`` before the constant exists."""
          return self._opr.outputs[0].value if self._opr else None

      def set_value(self, data, device=None):
          """(Re)build the constant var holding ``data`` in ``self.graph``.

          Args:
              data: Python int/float, NumPy integer/floating scalar, sequence,
                  or ndarray. float64 data is narrowed to float32 and int64
                  data to int32.
              device: comp node for the constant. If falsy, the node's current
                  ``device`` is used.
          """
          assert self.graph is not None
          cn = device if device else self.device
          # NumPy integer/floating scalars (np.float32, np.int32, ...) are
          # accepted like Python numbers. np.float64 used to pass only because
          # it subclasses float.
          assert isinstance(
              data, (int, float, np.integer, np.floating, Sequence, np.ndarray)
          )
          if not isinstance(data, np.ndarray):
              data = np.array(data)
          # Narrow 64-bit data to 32-bit before building the constant.
          if data.dtype == np.float64:
              data = data.astype(np.float32)
          elif data.dtype == np.int64:
              data = data.astype(np.int32)
          varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
          if not self.outputs:
              self.outputs.append(VarNode(owner_opr=self, name=self.name))
          self.outputs[0].var = varnode
          self._opr = varnode.owner

      @classmethod
      def load(cls, opr):
          """Wrap an existing constant operator ``opr`` from its graph."""
          obj = cls()
          obj._opr = opr
          obj.name = opr.outputs[0].name
          obj.graph = opr.graph
          return obj

      def compile(self, graph):
          """Bind the constant to ``graph`` and apply ``name`` to its output var.

          If the constant was built in another graph, it is re-created in
          ``graph`` from its current value.
          """
          assert self.outputs[0].var is self._opr.outputs[0]
          assert self.outputs[0].owner is self
          if self.graph != graph:
              self.graph = graph
              self.set_value(self.numpy())
          if self.name is not None:
              self.outputs[0].var.name = self.name
  class ReadOnlyOpNode(OpNode):
      """Fallback wrapper for operators with no dedicated ``OpNode`` subclass.

      ``opdef`` stays ``None``, so the operator cannot be rebuilt from
      ``params``. ``compile`` reuses the original runtime operator instead.
      When some inputs were rebound, it derives new output vars with
      ``replace_vars``.
      """

      @classmethod
      def load(cls, opr):
          obj = super().load(opr)
          obj.type = opr.type
          return obj

      def compile(self, graph=None):
          """Refresh output vars after input rebinding.

          ``graph`` is accepted and ignored, so the signature matches
          ``OpNode.compile(self, graph=None)`` and callers can invoke
          ``compile`` uniformly on any ``OpNode``.
          """
          assert self._opr is not None
          assert len(self.inputs) == len(self._opr.inputs)
          assert len(self.outputs) == len(self._opr.outputs)
          # Map each original input var to its replacement, for inputs whose
          # var no longer matches the operator's original input.
          repl_dict = {
              orig_var: node.var
              for node, orig_var in zip(self.inputs, self._opr.inputs)
              if node.var != orig_var
          }
          if repl_dict:
              out_vars = replace_vars(self._opr.outputs, repl_dict)
              for out_node, new_var in zip(self.outputs, out_vars):
                  out_node.var = new_var
  class Elemwise(OpNode):
      """``OpNode`` wrapper for ``builtin.Elemwise``; ``params["mode"]`` names the op."""

      opdef = builtin.Elemwise
      type = "Elemwise"

      def __repr__(self):
          mode = self.params["mode"]
          return "%s{Elemwise:%s}" % (self.name, mode)
  class ElemwiseMultiType(OpNode):
      """``OpNode`` wrapper for ``builtin.ElemwiseMultiType``.

      On load, the output dtype is copied into ``params["dtype"]``.
      """

      opdef = builtin.ElemwiseMultiType
      type = "ElemwiseMultiType"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node

      def __repr__(self):
          mode = self.params["mode"]
          return "%s{ElemwiseMultiType:%s}" % (self.name, mode)
  @register_flops(Elemwise, ElemwiseMultiType)
  def flops_elemwise(opnode: Elemwise, inputs, outputs):
      """Count one operation per element of the first output."""
      out_shape = outputs[0].shape
      return np.prod(out_shape)
  class Reduce(OpNode):
      """``OpNode`` wrapper for ``builtin.Reduce``."""

      opdef = builtin.Reduce
      type = "Reduce"


  class TypeCvt(OpNode):
      """``OpNode`` wrapper for ``builtin.TypeCvt``.

      The target dtype is read from the output var on load.
      """

      opdef = builtin.TypeCvt
      type = "TypeCvt"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node


  class MatrixInverse(OpNode):
      """``OpNode`` wrapper for ``builtin.MatrixInverse``."""

      opdef = builtin.MatrixInverse
      type = "MatrixInverse"


  class MatrixMul(OpNode):
      """``OpNode`` wrapper for ``builtin.MatrixMul``."""

      opdef = builtin.MatrixMul
      type = "MatrixMul"
  @register_flops(MatrixMul)
  def flops_matmul(opnode: MatrixMul, inputs, outputs):
      """FLOPs of a 2-D matmul: output element count times the reduction length.

      The reduction length is read from axis 1 of the first input.
      NOTE(review): this assumes the first operand is not transposed -- confirm.
      """
      lhs_shape = inputs[0].shape
      assert len(lhs_shape) == 2 and len(outputs[0].shape) == 2
      reduce_len = lhs_shape[1]
      return np.prod(outputs[0].shape) * reduce_len
  class BatchedMatrixMul(OpNode):
      """``OpNode`` wrapper for ``builtin.BatchedMatrixMul`` (type string ``BatchedMatmul``)."""

      opdef = builtin.BatchedMatrixMul
      type = "BatchedMatmul"
  @register_flops(BatchedMatrixMul)
  def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
      """FLOPs of a batched matmul: output element count times the reduction length.

      The reduction length is read from axis 2 of the first (3-D) input.
      NOTE(review): this assumes the first operand is not transposed -- confirm.
      """
      lhs_shape = inputs[0].shape
      assert len(lhs_shape) == 3 and len(outputs[0].shape) == 3
      reduce_len = lhs_shape[2]
      return np.prod(outputs[0].shape) * reduce_len
  class Dot(OpNode):
      """``OpNode`` wrapper for ``builtin.Dot``."""

      opdef = builtin.Dot
      type = "Dot"


  class SVD(OpNode):
      """``OpNode`` wrapper for ``builtin.SVD``."""

      opdef = builtin.SVD
      type = "SVD"


  class ConvolutionForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Convolution``."""

      opdef = builtin.Convolution
      type = "Convolution"


  class ConvolutionBackwardData(OpNode):
      """``OpNode`` wrapper for ``builtin.ConvolutionBackwardData`` (type ``ConvTranspose``)."""

      opdef = builtin.ConvolutionBackwardData
      type = "ConvTranspose"


  class DeformableConvForward(OpNode):
      """``OpNode`` wrapper for ``builtin.DeformableConv``."""

      opdef = builtin.DeformableConv
      type = "DeformableConv"


  class GroupLocalForward(OpNode):
      """``OpNode`` wrapper for ``builtin.GroupLocal``."""

      opdef = builtin.GroupLocal
      type = "GroupLocal"


  class PoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Pooling``."""

      opdef = builtin.Pooling
      type = "Pooling"


  class AdaptivePoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.AdaptivePooling``."""

      opdef = builtin.AdaptivePooling
      type = "AdaptivePooling"


  class ROIPoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.ROIPooling``."""

      opdef = builtin.ROIPooling
      type = "ROIPooling"


  class DeformablePSROIPoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.DeformablePSROIPooling``."""

      opdef = builtin.DeformablePSROIPooling
      type = "DeformablePSROIPooling"
  class ConvBiasForward(OpNode):
      """``OpNode`` wrapper for ``builtin.ConvBias``.

      On load, the output dtype is copied into ``params["dtype"]``.
      """

      opdef = builtin.ConvBias
      type = "ConvBias"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node
  @register_flops(
      ConvolutionForward, ConvBiasForward,
  )
  def flops_conv(opnode: ConvolutionForward, inputs, outputs):
      """FLOPs of a convolution, with one extra op per output element for ConvBias.

      The weight is ``inputs[1]``. Its last two axes are the kernel height and
      width. Its input-channel axis is 2 for 5-D (grouped) weights and 1
      otherwise.
      """
      weight_shape = inputs[1].shape
      kernel_h = weight_shape[-2]
      kernel_w = weight_shape[-1]
      channel_axis = 2 if len(weight_shape) == 5 else 1
      in_channels = weight_shape[channel_axis]
      out_elems = np.prod(outputs[0].shape)
      bias_term = 1 if isinstance(opnode, ConvBiasForward) else 0
      # N x Cout x H x W x (Cin x Kw x Kh), plus one op per element for the bias
      return out_elems * (in_channels * kernel_w * kernel_h + bias_term)
  @register_receptive_field(ConvolutionForward, ConvBiasForward)
  def receptive_field(opnode: ConvolutionForward, inputs, outputs):
      """Propagate the (h, w) receptive field and cumulative stride through a conv.

      The results are stored on ``opnode`` as ``_rf`` and ``_stride`` and also
      returned.
      """
      prev_rf, prev_stride = preprocess_receptive_field(opnode, inputs, outputs)
      weight_shape = inputs[1].shape
      kernel_h, kernel_w = weight_shape[-2], weight_shape[-1]
      step_h, step_w = prev_stride[0], prev_stride[1]
      rf_h = kernel_h * step_h + prev_rf[0] - step_h
      rf_w = kernel_w * step_w + prev_rf[1] - step_w
      new_stride = (
          opnode.params["stride_h"] * step_h,
          opnode.params["stride_w"] * step_w,
      )
      opnode._rf = (rf_h, rf_w)
      opnode._stride = new_stride
      return opnode._rf, new_stride
  class BatchConvBiasForward(OpNode):
      """``OpNode`` wrapper for ``builtin.BatchConvBias``.

      On load, the output dtype is copied into ``params["dtype"]``.
      """

      opdef = builtin.BatchConvBias
      type = "BatchConvBias"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["dtype"] = opr.outputs[0].dtype
          return node
  class BatchNormForward(OpNode):
      """``OpNode`` wrapper for ``builtin.BatchNorm``."""

      opdef = builtin.BatchNorm
      type = "BatchNorm"
      # NOTE(review): presumably selects which output is the primary result
      # (the last one) -- confirm against users of ``output_idx``.
      output_idx = -1


  class ROIAlignForward(OpNode):
      """``OpNode`` wrapper for ``builtin.ROIAlign``."""

      opdef = builtin.ROIAlign
      type = "ROIAlign"


  class WarpPerspectiveForward(OpNode):
      """``OpNode`` wrapper for ``builtin.WarpPerspective``."""

      opdef = builtin.WarpPerspective
      type = "WarpPerspective"


  class WarpAffineForward(OpNode):
      """``OpNode`` wrapper for ``builtin.WarpAffine``."""

      opdef = builtin.WarpAffine
      type = "WarpAffine"


  class RemapForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Remap``."""

      opdef = builtin.Remap
      type = "Remap"


  class ResizeForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Resize``."""

      opdef = builtin.Resize
      type = "Resize"


  class IndexingOneHot(OpNode):
      """``OpNode`` wrapper for ``builtin.IndexingOneHot``."""

      opdef = builtin.IndexingOneHot
      type = "IndexingOneHot"


  class IndexingSetOneHot(OpNode):
      """``OpNode`` wrapper for ``builtin.IndexingSetOneHot``."""

      opdef = builtin.IndexingSetOneHot
      type = "IndexingSetOneHot"
  class Copy(OpNode):
      """``OpNode`` wrapper for ``builtin.Copy``.

      On load, the output's comp node is stored as ``params["comp_node"]``.
      """

      opdef = builtin.Copy
      type = "Copy"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["comp_node"] = opr.outputs[0].comp_node
          return node
  class ArgsortForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Argsort``."""

      opdef = builtin.Argsort
      type = "Argsort"


  class Argmax(OpNode):
      """``OpNode`` wrapper for ``builtin.Argmax``."""

      opdef = builtin.Argmax
      type = "Argmax"


  class Argmin(OpNode):
      """``OpNode`` wrapper for ``builtin.Argmin``."""

      opdef = builtin.Argmin
      type = "Argmin"


  class CondTake(OpNode):
      """``OpNode`` wrapper for ``builtin.CondTake``."""

      opdef = builtin.CondTake
      type = "CondTake"


  class TopK(OpNode):
      """``OpNode`` wrapper for ``builtin.TopK``."""

      opdef = builtin.TopK
      type = "TopK"


  class NvOf(OpNode):
      """``OpNode`` wrapper for ``builtin.NvOf``."""

      opdef = builtin.NvOf
      type = "NvOf"
  class RNGOpr(OpNode):
      """Random-number operator, resolved at load time to Gaussian or Uniform RNG."""

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          # Exactly three decoded params selects Gaussian; any other count
          # selects Uniform.
          gaussian = len(node.params) == 3
          node.opdef = builtin.GaussianRNG if gaussian else builtin.UniformRNG
          node.type = "GaussianRNG" if gaussian else "UniformRNG"
          return node
  class Linspace(OpNode):
      """``OpNode`` wrapper for ``builtin.Linspace``.

      On load, the output's comp node is stored as ``params["comp_node"]``.
      """

      opdef = builtin.Linspace
      type = "Linspace"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["comp_node"] = opr.outputs[0].comp_node
          return node
  class Eye(OpNode):
      """``OpNode`` wrapper for ``builtin.Eye``.

      On load, the output's dtype and comp node are copied into ``params``.
      """

      opdef = builtin.Eye
      type = "Eye"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          out_var = opr.outputs[0]
          node.params["dtype"] = out_var.dtype
          node.params["comp_node"] = out_var.comp_node
          return node
  class GetVarShape(OpNode):
      """``OpNode`` wrapper for ``builtin.GetVarShape``."""

      opdef = builtin.GetVarShape
      type = "GetVarShape"


  class Concat(OpNode):
      """``OpNode`` wrapper for ``builtin.Concat``.

      On load, ``params["comp_node"]`` is always set to the ``xpux`` device.
      """

      opdef = builtin.Concat
      type = "Concat"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["comp_node"] = Device("xpux").to_c()
          return node
  class Broadcast(OpNode):
      """``OpNode`` wrapper for ``builtin.Broadcast``."""

      opdef = builtin.Broadcast
      type = "Broadcast"


  class Identity(OpNode):
      """``OpNode`` wrapper for ``builtin.Identity``."""

      opdef = builtin.Identity
      type = "Identity"


  class NMSKeep(OpNode):
      """``OpNode`` wrapper for ``builtin.NMSKeep``."""

      opdef = builtin.NMSKeep
      type = "NMSKeep"


  # Placeholders with no dedicated wrapper here: ParamPackSplit, ParamPackConcat
  class Dimshuffle(OpNode):
      """``OpNode`` wrapper for ``builtin.Dimshuffle``.

      The serialized ``ndim`` param is dropped on load, so it is not passed to
      ``opdef``.
      """

      opdef = builtin.Dimshuffle
      type = "Dimshuffle"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params.pop("ndim")
          return node
  class Reshape(OpNode):
      """``OpNode`` wrapper for ``builtin.Reshape``."""

      opdef = builtin.Reshape
      type = "Reshape"
  class AxisAddRemove(OpNode):
      """Adds or removes axes; resolved at load time to ``AddAxis`` or ``RemoveAxis``."""

      type = "AxisAddRemove"

      @classmethod
      def load(cls, opr):
          """Build from ``opr``, whose JSON params hold a ``desc`` list of
          ``{"method": ..., "axisnum": ...}`` entries.

          All entries must share one ``method``. Method ``0`` selects
          ``builtin.AddAxis``; any other value selects ``builtin.RemoveAxis``.
          """
          node = cls()
          node.name = opr.name
          node._opr = opr
          entries = json.loads(opr.params)["desc"]
          shared_method = None
          axes = []
          for entry in entries:
              if shared_method is None:
                  shared_method = entry["method"]
              assert shared_method == entry["method"]
              axes.append(entry["axisnum"])
          node.params = {"axis": axes}
          if entries[0]["method"] == 0:
              node.opdef = builtin.AddAxis
          else:
              node.opdef = builtin.RemoveAxis
          return node
  class IndexingBase(OpNode):
      """Shared loader for the subtensor / indexing operator wrappers."""

      @classmethod
      def load(cls, opr):
          """Build from ``opr``, whose JSON params are a list of per-axis dicts.

          Each dict becomes ``[axis, bool(begin), bool(end), bool(step),
          bool(idx)]`` in ``params["items"]``.
          """
          node = cls()
          node.name = opr.name
          node._opr = opr
          flag_keys = ("begin", "end", "step", "idx")
          node.params["items"] = [
              [desc["axis"]] + [bool(desc[key]) for key in flag_keys]
              for desc in json.loads(opr.params)
          ]
          return node
  class Subtensor(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.Subtensor``."""

      opdef = builtin.Subtensor
      type = "Subtensor"


  class SetSubtensor(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.SetSubtensor``."""

      opdef = builtin.SetSubtensor
      type = "SetSubtensor"


  class IncrSubtensor(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IncrSubtensor``."""

      opdef = builtin.IncrSubtensor
      type = "IncrSubtensor"


  class IndexingMultiAxisVec(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IndexingMultiAxisVec``."""

      opdef = builtin.IndexingMultiAxisVec
      type = "IndexingMultiAxisVec"


  class IndexingSetMultiAxisVec(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IndexingSetMultiAxisVec``."""

      opdef = builtin.IndexingSetMultiAxisVec
      type = "IndexingSetMultiAxisVec"


  class IndexingIncrMultiAxisVec(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IndexingIncrMultiAxisVec``."""

      opdef = builtin.IndexingIncrMultiAxisVec
      type = "IndexingIncrMultiAxisVec"


  class MeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.MeshIndexing``."""

      opdef = builtin.MeshIndexing
      type = "MeshIndexing"


  class SetMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.SetMeshIndexing``."""

      opdef = builtin.SetMeshIndexing
      type = "SetMeshIndexing"


  class IncrMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IncrMeshIndexing``."""

      opdef = builtin.IncrMeshIndexing
      type = "IncrMeshIndexing"


  class BatchedMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.BatchedMeshIndexing``."""

      opdef = builtin.BatchedMeshIndexing
      type = "BatchedMeshIndexing"


  class BatchedSetMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.BatchedSetMeshIndexing``."""

      opdef = builtin.BatchedSetMeshIndexing
      type = "BatchedSetMeshIndexing"


  class BatchedIncrMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.BatchedIncrMeshIndexing``."""

      opdef = builtin.BatchedIncrMeshIndexing
      type = "BatchedIncrMeshIndexing"
  # Placeholders with no dedicated wrapper here:
  # CollectiveComm, RemoteSend, RemoteRecv, TQT, FakeQuant, InplaceAdd


  class AssertEqual(OpNode):
      """``OpNode`` wrapper for ``builtin.AssertEqual``."""

      opdef = builtin.AssertEqual
      type = "AssertEqual"


  class CvtColorForward(OpNode):
      """``OpNode`` wrapper for ``builtin.CvtColor``."""

      opdef = builtin.CvtColor
      type = "CvtColor"

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。