You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 19 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Sequence
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._imperative_rt.core2 import SymbolVar, apply
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core._wrap import Device
  17. from ..core.ops import builtin
  18. from ..core.tensor.array_method import ArrayMethodMixin
  19. from ..core.tensor.megbrain_graph import OutputNode
  20. from .comp_graph_tools import replace_vars
  21. from .module_stats import (
  22. preprocess_receptive_field,
  23. register_flops,
  24. register_receptive_field,
  25. )
  26. class NetworkNode:
  27. pass
  28. class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
  29. pass
  30. class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
  31. def __init__(self, var=None, *, owner_opr=None, name=None):
  32. SymbolVar.__init__(self, var)
  33. self.owner = owner_opr
  34. self.name = name
  35. self.id = id(self)
  36. @classmethod
  37. def load(cls, sym_var, owner_opr):
  38. obj = cls()
  39. obj.var = sym_var # mgb varnode
  40. obj.name = sym_var.name
  41. obj.owner = owner_opr
  42. return obj
  43. def _get_var_shape(self, axis=None):
  44. opdef = (
  45. builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
  46. )
  47. return apply(opdef, self)[0]
  48. @property
  49. def partial_shape(self):
  50. """Return the tuple type inferred shape of VarNode
  51. """
  52. return tuple(self._get_var_shape().numpy())
  53. def shapeof(self, axis):
  54. """Return the symbolic shape of axis
  55. """
  56. return self._get_var_shape(axis=axis) if self.var else None
  57. @property
  58. def _tuple_shape(self):
  59. return self.partial_shape
  60. @property
  61. def shape(self):
  62. """Return the symbolic shape if using set_symbolic_shape(True)
  63. else inferred shape
  64. """
  65. rst = None
  66. if self.var:
  67. try:
  68. rst = self.var.shape
  69. except:
  70. rst = None
  71. if not use_symbolic_shape():
  72. return rst
  73. return self._get_var_shape() if self.var else None
  74. @property
  75. def dtype(self):
  76. return self.var.dtype if self.var else None
  77. def __bool__(self):
  78. return False
  79. __index__ = None
  80. __int__ = None
  81. __float__ = None
  82. __complex__ = None
  83. def __hash__(self):
  84. return id(self)
  85. def numpy(self):
  86. o = OutputNode(self.var)
  87. self.graph.compile(o.outputs).execute()
  88. return o.get_value().numpy()
  89. def _reset(self, other):
  90. if not isinstance(other, VarNode):
  91. assert self.graph, "VarNode _reset must have graph"
  92. node = ImmutableTensor(other, graph=self.graph)
  93. node.compile(self.graph)
  94. other = node.outputs[0]
  95. if self.owner is not None:
  96. idx = self.owner.outputs.index(self)
  97. self.owner.outputs[idx] = VarNode(
  98. self.var, owner_opr=self.owner, name=self.var.name
  99. )
  100. self.var = other.var
  101. self.owner = None
  102. def set_owner_opr(self, owner_opr):
  103. self.owner = owner_opr
  104. class OpNode(NetworkNode):
  105. opdef = None
  106. type = None
  107. def __init__(self):
  108. self.inputs = []
  109. self.outputs = []
  110. self.params = {}
  111. self._opr = None # mgb opnode
  112. self.id = id(self)
  113. @classmethod
  114. def load(cls, opr):
  115. obj = cls()
  116. obj.params = json.loads(opr.params)
  117. obj.name = opr.name
  118. obj._opr = opr
  119. return obj
  120. def compile(self, graph=None):
  121. op = self.opdef(**self.params)
  122. args = [i.var for i in self.inputs]
  123. outputs = rt.invoke_op(op, args)
  124. assert len(outputs) == len(self.outputs)
  125. self._opr = outputs[0].owner
  126. for i in range(len(self.outputs)):
  127. self.outputs[i].var = outputs[i]
  128. self.outputs[i].var.name = self.outputs[i].name
  129. assert self.outputs[i].owner is self
  130. def add_inp_var(self, x):
  131. self.inputs.append(x)
  132. def add_out_var(self, x):
  133. self.outputs.append(x)
  134. def __repr__(self):
  135. return "%s{%s}" % (self.name, self.type)
  136. def str_to_mge_class(classname):
  137. # TODO: use megbrain C++ RTTI to replace type string
  138. if classname == "RNGOpr<MegDNNOpr>":
  139. classname = "RNGOpr"
  140. oprcls = getattr(sys.modules[__name__], classname, None)
  141. return oprcls if oprcls else ReadOnlyOpNode
  142. class Host2DeviceCopy(OpNode):
  143. type = "Host2DeviceCopy"
  144. def __init__(self, shape=None, dtype=None, name=None, device=None):
  145. super().__init__()
  146. self.shape = shape
  147. self.dtype = dtype
  148. self.name = name
  149. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  150. self.outputs = []
  151. @classmethod
  152. def load(cls, opr):
  153. self = cls()
  154. self.outputs = []
  155. assert len(opr.outputs) == 1, "wrong number of outputs"
  156. self.shape = opr.outputs[0].shape
  157. self.dtype = opr.outputs[0].dtype
  158. self.name = opr.outputs[0].name
  159. self.device = opr.outputs[0].comp_node
  160. self._opr = opr
  161. return self
  162. def compile(self, graph):
  163. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  164. self._opr = outputs.owner
  165. if len(self.outputs) == 0:
  166. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  167. self.outputs[0].var = outputs
  168. assert self.outputs[0].owner is self
  169. class ImmutableTensor(OpNode):
  170. type = "ImmutableTensor"
  171. def __init__(self, data=None, name=None, device=None, graph=None):
  172. super().__init__()
  173. self.name = name
  174. self.outputs = []
  175. self.graph = graph
  176. if data is not None:
  177. self.set_value(data, device)
  178. @property
  179. def device(self):
  180. return self._opr.outputs[0].comp_node if self._opr else None
  181. @device.setter
  182. def device(self, device):
  183. self.set_value(self.numpy(), device)
  184. @property
  185. def shape(self):
  186. return self.outputs[0].shape
  187. @property
  188. def dtype(self):
  189. return self._opr.outputs[0].dtype if self._opr else None
  190. def numpy(self):
  191. return self._opr.outputs[0].value if self._opr else None
  192. def set_value(self, data, device=None):
  193. assert self.graph is not None
  194. cn = device if device else self.device
  195. assert isinstance(data, (int, float, Sequence, np.ndarray))
  196. if not isinstance(data, np.ndarray):
  197. data = np.array(data)
  198. if data.dtype == np.float64:
  199. data = data.astype(np.float32)
  200. elif data.dtype == np.int64:
  201. data = data.astype(np.int32)
  202. varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
  203. if len(self.outputs) == 0:
  204. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  205. self.outputs[0].var = varnode
  206. self._opr = varnode.owner
  207. @classmethod
  208. def load(cls, opr):
  209. self = cls()
  210. self.outputs = []
  211. self._opr = opr
  212. self.name = opr.outputs[0].name
  213. self.graph = opr.graph
  214. return self
  215. def compile(self, graph):
  216. assert self.outputs[0].var is self._opr.outputs[0]
  217. assert self.outputs[0].owner is self
  218. if self.graph != graph:
  219. self.graph = graph
  220. self.set_value(self.numpy())
  221. if self.name is not None:
  222. self.outputs[0].var.name = self.name
  223. class ReadOnlyOpNode(OpNode):
  224. @classmethod
  225. def load(cls, opr):
  226. obj = super(ReadOnlyOpNode, cls).load(opr)
  227. obj.type = opr.type
  228. return obj
  229. def compile(self):
  230. assert self._opr is not None
  231. assert len(self.inputs) == len(self._opr.inputs)
  232. assert len(self.outputs) == len(self._opr.outputs)
  233. repl_dict = {}
  234. for ind, i in enumerate(self.inputs):
  235. if i.var != self._opr.inputs[ind]:
  236. repl_dict[self._opr.inputs[ind]] = i.var
  237. if bool(repl_dict):
  238. out_vars = replace_vars(self._opr.outputs, repl_dict)
  239. for ind, o in enumerate(self.outputs):
  240. o.var = out_vars[ind]
  241. class Elemwise(OpNode):
  242. type = "Elemwise"
  243. opdef = builtin.Elemwise
  244. def __repr__(self):
  245. return "%s{Elemwise:%s}" % (self.name, self.params["mode"])
  246. class ElemwiseMultiType(OpNode):
  247. type = "ElemwiseMultiType"
  248. opdef = builtin.ElemwiseMultiType
  249. def __repr__(self):
  250. return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])
  251. @classmethod
  252. def load(cls, opr):
  253. obj = super(ElemwiseMultiType, cls).load(opr)
  254. obj.params["dtype"] = opr.outputs[0].dtype
  255. return obj
  256. @register_flops(Elemwise, ElemwiseMultiType)
  257. def flops_elemwise(opnode: Elemwise, inputs, outputs):
  258. return np.prod(outputs[0].shape)
class Reduce(OpNode):
    """Op node whose ``opdef`` is ``builtin.Reduce``."""

    type = "Reduce"
    opdef = builtin.Reduce
  262. class TypeCvt(OpNode):
  263. type = "TypeCvt"
  264. opdef = builtin.TypeCvt
  265. @classmethod
  266. def load(cls, opr):
  267. obj = super(TypeCvt, cls).load(opr)
  268. t_dtype = opr.outputs[0].dtype
  269. obj.params["dtype"] = t_dtype
  270. return obj
class MatrixInverse(OpNode):
    """Op node whose ``opdef`` is ``builtin.MatrixInverse``."""

    type = "MatrixInverse"
    opdef = builtin.MatrixInverse


class MatrixMul(OpNode):
    """Op node whose ``opdef`` is ``builtin.MatrixMul``."""

    type = "MatrixMul"
    opdef = builtin.MatrixMul
  277. @register_flops(MatrixMul)
  278. def flops_matmul(opnode: MatrixMul, inputs, outputs):
  279. assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
  280. mid_shape = inputs[0].shape[1]
  281. return np.prod(outputs[0].shape) * mid_shape
class BatchedMatrixMul(OpNode):
    """Op node whose ``opdef`` is ``builtin.BatchedMatrixMul``."""

    # The type string ("BatchedMatmul") differs from the class name.
    type = "BatchedMatmul"
    opdef = builtin.BatchedMatrixMul
  285. @register_flops(BatchedMatrixMul)
  286. def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
  287. assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
  288. mid_shape = inputs[0].shape[2]
  289. return np.prod(outputs[0].shape) * mid_shape
class Dot(OpNode):
    """Op node whose ``opdef`` is ``builtin.Dot``."""

    type = "Dot"
    opdef = builtin.Dot


class SVD(OpNode):
    """Op node whose ``opdef`` is ``builtin.SVD``."""

    type = "SVD"
    opdef = builtin.SVD


class ConvolutionForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.Convolution``."""

    type = "Convolution"
    opdef = builtin.Convolution


class ConvolutionBackwardData(OpNode):
    """Op node whose ``opdef`` is ``builtin.ConvolutionBackwardData``."""

    # The type string ("ConvTranspose") differs from the class name.
    type = "ConvTranspose"
    opdef = builtin.ConvolutionBackwardData


class DeformableConvForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.DeformableConv``."""

    type = "DeformableConv"
    opdef = builtin.DeformableConv


class GroupLocalForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.GroupLocal``."""

    type = "GroupLocal"
    opdef = builtin.GroupLocal


class PoolingForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.Pooling``."""

    type = "Pooling"
    opdef = builtin.Pooling


class AdaptivePoolingForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.AdaptivePooling``."""

    type = "AdaptivePooling"
    opdef = builtin.AdaptivePooling


class ROIPoolingForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.ROIPooling``."""

    type = "ROIPooling"
    opdef = builtin.ROIPooling


class DeformablePSROIPoolingForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.DeformablePSROIPooling``."""

    type = "DeformablePSROIPooling"
    opdef = builtin.DeformablePSROIPooling
  320. class ConvBiasForward(OpNode):
  321. type = "ConvBias"
  322. opdef = builtin.ConvBias
  323. @classmethod
  324. def load(cls, opr):
  325. obj = super(ConvBiasForward, cls).load(opr)
  326. obj.params["dtype"] = opr.outputs[0].dtype
  327. return obj
  328. @register_flops(
  329. ConvolutionForward, ConvBiasForward,
  330. )
  331. def flops_conv(opnode: ConvolutionForward, inputs, outputs):
  332. param_W_shape = inputs[1].shape
  333. kh = param_W_shape[-2]
  334. kw = param_W_shape[-1]
  335. if len(param_W_shape) == 5:
  336. num_input = param_W_shape[2]
  337. else:
  338. num_input = param_W_shape[1]
  339. NCHW = np.prod(outputs[0].shape)
  340. bias = 1 if isinstance(opnode, ConvBiasForward) else 0
  341. # N x Cout x H x W x (Cin x Kw x Kh)
  342. return NCHW * (num_input * kw * kh + bias)
  343. @register_receptive_field(ConvolutionForward, ConvBiasForward)
  344. def receptive_field(opnode: ConvolutionForward, inputs, outputs):
  345. pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
  346. param_W_shape = inputs[1].shape
  347. kh = param_W_shape[-2]
  348. kw = param_W_shape[-1]
  349. rf = (
  350. kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
  351. kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
  352. )
  353. stride = (
  354. opnode.params["stride_h"] * pre_stride[0],
  355. opnode.params["stride_w"] * pre_stride[1],
  356. )
  357. opnode._rf = rf
  358. opnode._stride = stride
  359. return rf, stride
  360. class BatchConvBiasForward(OpNode):
  361. type = "BatchConvBias"
  362. opdef = builtin.BatchConvBias
  363. @classmethod
  364. def load(cls, opr):
  365. obj = super(BatchConvBiasForward, cls).load(opr)
  366. obj.params["dtype"] = opr.outputs[0].dtype
  367. return obj
class BatchNormForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.BatchNorm``."""

    type = "BatchNorm"
    opdef = builtin.BatchNorm
    # NOTE(review): presumably the index of the output treated as the primary
    # result -- confirm against the code that reads ``output_idx``.
    output_idx = -1


class ROIAlignForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.ROIAlign``."""

    type = "ROIAlign"
    opdef = builtin.ROIAlign


class WarpPerspectiveForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.WarpPerspective``."""

    type = "WarpPerspective"
    opdef = builtin.WarpPerspective


class WarpAffineForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.WarpAffine``."""

    type = "WarpAffine"
    opdef = builtin.WarpAffine


class RemapForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.Remap``."""

    type = "Remap"
    opdef = builtin.Remap


class ResizeForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.Resize``."""

    type = "Resize"
    opdef = builtin.Resize


class IndexingOneHot(OpNode):
    """Op node whose ``opdef`` is ``builtin.IndexingOneHot``."""

    type = "IndexingOneHot"
    opdef = builtin.IndexingOneHot


class IndexingSetOneHot(OpNode):
    """Op node whose ``opdef`` is ``builtin.IndexingSetOneHot``."""

    type = "IndexingSetOneHot"
    opdef = builtin.IndexingSetOneHot
  393. class Copy(OpNode):
  394. type = "Copy"
  395. opdef = builtin.Copy
  396. @classmethod
  397. def load(cls, opr):
  398. obj = super(Copy, cls).load(opr)
  399. obj.params["comp_node"] = opr.outputs[0].comp_node
  400. return obj
class ArgsortForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.Argsort``."""

    type = "Argsort"
    opdef = builtin.Argsort


class Argmax(OpNode):
    """Op node whose ``opdef`` is ``builtin.Argmax``."""

    type = "Argmax"
    opdef = builtin.Argmax


class Argmin(OpNode):
    """Op node whose ``opdef`` is ``builtin.Argmin``."""

    type = "Argmin"
    opdef = builtin.Argmin


class CondTake(OpNode):
    """Op node whose ``opdef`` is ``builtin.CondTake``."""

    type = "CondTake"
    opdef = builtin.CondTake


class TopK(OpNode):
    """Op node whose ``opdef`` is ``builtin.TopK``."""

    type = "TopK"
    opdef = builtin.TopK


class NvOf(OpNode):
    """Op node whose ``opdef`` is ``builtin.NvOf``."""

    type = "NvOf"
    opdef = builtin.NvOf
  419. class RNGOpr(OpNode):
  420. @classmethod
  421. def load(cls, opr):
  422. obj = super(RNGOpr, cls).load(opr)
  423. if len(obj.params) == 3:
  424. obj.opdef = builtin.GaussianRNG
  425. obj.type = "GaussianRNG"
  426. else:
  427. obj.opdef = builtin.UniformRNG
  428. obj.type = "UniformRNG"
  429. return obj
  430. class Linspace(OpNode):
  431. type = "Linspace"
  432. opdef = builtin.Linspace
  433. @classmethod
  434. def load(cls, opr):
  435. obj = super(Linspace, cls).load(opr)
  436. obj.params["comp_node"] = opr.outputs[0].comp_node
  437. return obj
  438. class Eye(OpNode):
  439. type = "Eye"
  440. opdef = builtin.Eye
  441. @classmethod
  442. def load(cls, opr):
  443. obj = super(Eye, cls).load(opr)
  444. obj.params["dtype"] = opr.outputs[0].dtype
  445. obj.params["comp_node"] = opr.outputs[0].comp_node
  446. return obj
class GetVarShape(OpNode):
    """Op node whose ``opdef`` is ``builtin.GetVarShape``."""

    type = "GetVarShape"
    opdef = builtin.GetVarShape
  450. class Concat(OpNode):
  451. type = "Concat"
  452. opdef = builtin.Concat
  453. @classmethod
  454. def load(cls, opr):
  455. obj = super(Concat, cls).load(opr)
  456. obj.params["comp_node"] = Device("xpux").to_c()
  457. return obj
class Broadcast(OpNode):
    """Op node whose ``opdef`` is ``builtin.Broadcast``."""

    type = "Broadcast"
    opdef = builtin.Broadcast


class Identity(OpNode):
    """Op node whose ``opdef`` is ``builtin.Identity``."""

    type = "Identity"
    opdef = builtin.Identity


class NMSKeep(OpNode):
    """Op node whose ``opdef`` is ``builtin.NMSKeep``."""

    type = "NMSKeep"
    opdef = builtin.NMSKeep
  467. # class ParamPackSplit
  468. # class ParamPackConcat
  469. class Dimshuffle(OpNode):
  470. type = "Dimshuffle"
  471. opdef = builtin.Dimshuffle
  472. @classmethod
  473. def load(cls, opr):
  474. obj = super(Dimshuffle, cls).load(opr)
  475. del obj.params["ndim"]
  476. return obj
class Reshape(OpNode):
    """Op node whose ``opdef`` is ``builtin.Reshape``."""

    type = "Reshape"
    opdef = builtin.Reshape
  480. class AxisAddRemove(OpNode):
  481. type = "AxisAddRemove"
  482. @classmethod
  483. def load(cls, opr):
  484. obj = cls()
  485. obj.name = opr.name
  486. obj._opr = opr
  487. params = json.loads(opr.params)
  488. desc = params["desc"]
  489. method = None
  490. axis = []
  491. for i in desc:
  492. if method is None:
  493. method = i["method"]
  494. assert method == i["method"]
  495. axis.append(i["axisnum"])
  496. obj.params = {"axis": axis}
  497. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  498. return obj
  499. class IndexingBase(OpNode):
  500. @classmethod
  501. def load(cls, opr):
  502. obj = cls()
  503. obj.name = opr.name
  504. obj._opr = opr
  505. params = json.loads(opr.params)
  506. items = [
  507. [
  508. p["axis"],
  509. bool(p["begin"]),
  510. bool(p["end"]),
  511. bool(p["step"]),
  512. bool(p["idx"]),
  513. ]
  514. for p in params
  515. ]
  516. obj.params["items"] = items
  517. return obj
class Subtensor(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.Subtensor``."""

    type = "Subtensor"
    opdef = builtin.Subtensor


class SetSubtensor(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.SetSubtensor``."""

    type = "SetSubtensor"
    opdef = builtin.SetSubtensor


class IncrSubtensor(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.IncrSubtensor``."""

    type = "IncrSubtensor"
    opdef = builtin.IncrSubtensor


class IndexingMultiAxisVec(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.IndexingMultiAxisVec``."""

    type = "IndexingMultiAxisVec"
    opdef = builtin.IndexingMultiAxisVec


class IndexingSetMultiAxisVec(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.IndexingSetMultiAxisVec``."""

    type = "IndexingSetMultiAxisVec"
    opdef = builtin.IndexingSetMultiAxisVec


class IndexingIncrMultiAxisVec(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.IndexingIncrMultiAxisVec``."""

    type = "IndexingIncrMultiAxisVec"
    opdef = builtin.IndexingIncrMultiAxisVec


class MeshIndexing(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.MeshIndexing``."""

    type = "MeshIndexing"
    opdef = builtin.MeshIndexing


class SetMeshIndexing(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.SetMeshIndexing``."""

    type = "SetMeshIndexing"
    opdef = builtin.SetMeshIndexing


class IncrMeshIndexing(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.IncrMeshIndexing``."""

    type = "IncrMeshIndexing"
    opdef = builtin.IncrMeshIndexing


class BatchedMeshIndexing(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.BatchedMeshIndexing``."""

    type = "BatchedMeshIndexing"
    opdef = builtin.BatchedMeshIndexing


class BatchedSetMeshIndexing(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.BatchedSetMeshIndexing``."""

    type = "BatchedSetMeshIndexing"
    opdef = builtin.BatchedSetMeshIndexing


class BatchedIncrMeshIndexing(IndexingBase):
    """Indexing op node whose ``opdef`` is ``builtin.BatchedIncrMeshIndexing``."""

    type = "BatchedIncrMeshIndexing"
    opdef = builtin.BatchedIncrMeshIndexing
  554. # class CollectiveComm
  555. # class RemoteSend
  556. # class RemoteRecv
  557. # class TQT
  558. # class FakeQuant
  559. # class InplaceAdd
class AssertEqual(OpNode):
    """Op node whose ``opdef`` is ``builtin.AssertEqual``."""

    type = "AssertEqual"
    opdef = builtin.AssertEqual


class CvtColorForward(OpNode):
    """Op node whose ``opdef`` is ``builtin.CvtColor``."""

    type = "CvtColor"
    opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台