You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 20 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Sequence
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._imperative_rt.core2 import SymbolVar, apply
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core._wrap import Device
  17. from ..core.ops import builtin
  18. from ..core.tensor.array_method import ArrayMethodMixin
  19. from .comp_graph_tools import replace_vars
  20. from .module_stats import (
  21. preprocess_receptive_field,
  22. register_flops,
  23. register_receptive_field,
  24. )
  25. class NetworkNode:
  26. pass
  27. class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
  28. pass
  29. class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
  30. def __init__(self, var=None, *, owner_opr=None, name=None):
  31. SymbolVar.__init__(self, var)
  32. self.users = [] # List[OpNode]
  33. self.owner = owner_opr
  34. self.name = name
  35. self.id = id(self)
  36. @classmethod
  37. def load(cls, sym_var, owner_opr):
  38. obj = cls()
  39. obj.var = sym_var # mgb varnode
  40. obj.name = sym_var.name
  41. obj.owner = owner_opr
  42. return obj
  43. def _get_var_shape(self, axis=None):
  44. opdef = (
  45. builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
  46. )
  47. return apply(opdef, self)[0]
  48. @property
  49. def partial_shape(self):
  50. r"""Return the tuple type inferred shape of VarNode"""
  51. return tuple(self._get_var_shape().numpy())
  52. def shapeof(self, axis):
  53. r"""Return the symbolic shape of axis"""
  54. return self._get_var_shape(axis=axis) if self.var else None
  55. @property
  56. def _tuple_shape(self):
  57. return self.partial_shape
  58. @property
  59. def shape(self):
  60. r"""Return the symbolic shape if using set_symbolic_shape(True)
  61. else inferred shape
  62. """
  63. rst = None
  64. if self.var:
  65. try:
  66. rst = self.var.shape
  67. except:
  68. rst = None
  69. if not use_symbolic_shape():
  70. return rst
  71. return self._get_var_shape() if self.var else None
  72. @property
  73. def dtype(self):
  74. return self.var.dtype if self.var else None
  75. @property
  76. def ndim(self):
  77. return super().ndim
  78. def __bool__(self):
  79. return False
  80. __index__ = None
  81. __int__ = None
  82. __float__ = None
  83. __complex__ = None
  84. def __hash__(self):
  85. return id(self)
  86. def numpy(self):
  87. return super().numpy()
  88. def _reset(self, other):
  89. if not isinstance(other, VarNode):
  90. assert self.graph, "VarNode _reset must have graph"
  91. node = ImmutableTensor(other, graph=self.graph)
  92. node.compile(self.graph)
  93. other = node.outputs[0]
  94. if self.owner is not None:
  95. idx = self.owner.outputs.index(self)
  96. self.owner.outputs[idx] = VarNode(
  97. self.var, owner_opr=self.owner, name=self.var.name
  98. )
  99. self.var = other.var
  100. self.owner = None
  101. def set_owner_opr(self, owner_opr):
  102. self.owner = owner_opr
  103. class OpNode(NetworkNode):
  104. opdef = None
  105. type = None
  106. def __init__(self):
  107. self.inputs = []
  108. self.outputs = []
  109. self.params = {}
  110. self._opr = None # mgb opnode
  111. @property
  112. def id(self):
  113. return id(self)
  114. @property
  115. def priority(self):
  116. if self._opr is not None:
  117. return (self._opr.priority, self._opr.id)
  118. return (0, 0)
  119. @classmethod
  120. def load(cls, opr):
  121. obj = cls()
  122. obj.params = json.loads(opr.params)
  123. obj.name = opr.name
  124. obj._opr = opr
  125. return obj
  126. def compile(self):
  127. if (
  128. self._opr is None
  129. or len(self._opr.inputs) != len(self.inputs)
  130. or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
  131. ):
  132. op = self.opdef(**self.params)
  133. args = [i.var for i in self.inputs]
  134. outputs = rt.invoke_op(op, args)
  135. assert len(outputs) == len(self.outputs)
  136. self._opr = outputs[0].owner
  137. for i in range(len(self.outputs)):
  138. self.outputs[i].var = outputs[i]
  139. self.outputs[i].var.name = self.outputs[i].name
  140. assert self.outputs[i].owner is self
  141. def add_inp_var(self, x):
  142. self.inputs.append(x)
  143. def add_out_var(self, x):
  144. self.outputs.append(x)
  145. def __repr__(self):
  146. return "%s{%s}" % (self.name, self.type)
  147. def str_to_mge_class(classname):
  148. # TODO: use megbrain C++ RTTI to replace type string
  149. if classname == "RNGOpr<MegDNNOpr>":
  150. classname = "RNGOpr"
  151. oprcls = getattr(sys.modules[__name__], classname, None)
  152. return oprcls if oprcls else ReadOnlyOpNode
  153. class Host2DeviceCopy(OpNode):
  154. type = "Host2DeviceCopy"
  155. def __init__(self, shape=None, dtype=None, name=None, device=None):
  156. super().__init__()
  157. self.shape = shape
  158. self.dtype = dtype
  159. self.name = name
  160. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  161. self.outputs = []
  162. @classmethod
  163. def load(cls, opr):
  164. self = cls()
  165. self.outputs = []
  166. assert len(opr.outputs) == 1, "wrong number of outputs"
  167. self.shape = opr.outputs[0].shape
  168. self.dtype = opr.outputs[0].dtype
  169. self.name = opr.outputs[0].name
  170. self.device = opr.outputs[0].comp_node
  171. self._opr = opr
  172. return self
  173. def compile(self, graph):
  174. if (
  175. self._opr is None
  176. or self._opr.graph != graph
  177. or self._opr.outputs[0].comp_node != self.device
  178. or self._opr.outputs[0].shape != self.shape
  179. or self._opr.outputs[0].dtype != self.dtype
  180. ):
  181. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  182. self._opr = outputs.owner
  183. if len(self.outputs) == 0:
  184. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  185. self.outputs[0].var = outputs
  186. assert self.outputs[0].owner is self
  187. class ConstOpBase(OpNode):
  188. type = "ConstOpBase"
  189. def __init__(self, data=None, name=None, device=None, graph=None):
  190. assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
  191. super().__init__()
  192. self.name = name
  193. self.outputs = []
  194. self.graph = graph
  195. if data is not None:
  196. self.set_value(data, device)
  197. @property
  198. def device(self):
  199. return self._opr.outputs[0].comp_node if self._opr else None
  200. @device.setter
  201. def device(self, device):
  202. self.set_value(self.numpy(), device)
  203. @property
  204. def shape(self):
  205. return self.outputs[0].shape
  206. @property
  207. def dtype(self):
  208. return self._opr.outputs[0].dtype if self._opr else None
  209. def numpy(self):
  210. return self.outputs[0].numpy()
  211. def set_value(self, data, device=None):
  212. assert self.graph is not None
  213. cn = device if device else self.device
  214. assert isinstance(data, (int, float, Sequence, np.ndarray))
  215. if not isinstance(data, np.ndarray):
  216. data = np.array(data)
  217. if data.dtype == np.float64:
  218. data = data.astype(np.float32)
  219. elif data.dtype == np.int64:
  220. data = data.astype(np.int32)
  221. varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
  222. if len(self.outputs) == 0:
  223. self.outputs.append(VarNode(owner_opr=self, name=self.name))
  224. self.outputs[0].var = varnode
  225. self._opr = varnode.owner
  226. @classmethod
  227. def load(cls, opr):
  228. self = cls()
  229. self.outputs = []
  230. self._opr = opr
  231. self.name = opr.outputs[0].name
  232. self.graph = opr.graph
  233. return self
  234. def compile(self, graph):
  235. assert self.outputs[0].var is self._opr.outputs[0]
  236. assert self.outputs[0].owner is self
  237. if self.graph != graph:
  238. self.graph = graph
  239. self.set_value(self.numpy())
  240. if self.name is not None:
  241. self.outputs[0].var.name = self.name
  242. class ImmutableTensor(ConstOpBase):
  243. type = "ImmutableTensor"
  244. rt_fun = rt.make_const
  245. class SharedDeviceTensor(ConstOpBase):
  246. type = "SharedDeviceTensor"
  247. rt_fun = rt.make_shared
  248. class ReadOnlyOpNode(OpNode):
  249. @classmethod
  250. def load(cls, opr):
  251. obj = super(ReadOnlyOpNode, cls).load(opr)
  252. obj.type = opr.type
  253. return obj
  254. def compile(self):
  255. assert self._opr is not None
  256. assert len(self.inputs) == len(self._opr.inputs)
  257. assert len(self.outputs) == len(self._opr.outputs)
  258. repl_dict = {}
  259. for ind, i in enumerate(self.inputs):
  260. if i.var != self._opr.inputs[ind]:
  261. repl_dict[self._opr.inputs[ind]] = i.var
  262. if bool(repl_dict):
  263. out_vars = replace_vars(self._opr.outputs, repl_dict)
  264. for ind, o in enumerate(self.outputs):
  265. o.var = out_vars[ind]
  266. class Elemwise(OpNode):
  267. type = "Elemwise"
  268. opdef = builtin.Elemwise
  269. def __repr__(self):
  270. return "%s{Elemwise:%s}" % (self.name, self.params["mode"])
  271. class ElemwiseMultiType(OpNode):
  272. type = "ElemwiseMultiType"
  273. opdef = builtin.ElemwiseMultiType
  274. def __repr__(self):
  275. return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])
  276. @classmethod
  277. def load(cls, opr):
  278. obj = super(ElemwiseMultiType, cls).load(opr)
  279. obj.params["dtype"] = opr.outputs[0].dtype
  280. return obj
  281. @register_flops(Elemwise, ElemwiseMultiType)
  282. def flops_elemwise(opnode: Elemwise, inputs, outputs):
  283. return np.prod(outputs[0].shape)
  284. class Reduce(OpNode):
  285. type = "Reduce"
  286. opdef = builtin.Reduce
  287. class TypeCvt(OpNode):
  288. type = "TypeCvt"
  289. opdef = builtin.TypeCvt
  290. @classmethod
  291. def load(cls, opr):
  292. obj = super(TypeCvt, cls).load(opr)
  293. t_dtype = opr.outputs[0].dtype
  294. obj.params["dtype"] = t_dtype
  295. return obj
  296. class MatrixInverse(OpNode):
  297. type = "MatrixInverse"
  298. opdef = builtin.MatrixInverse
  299. class MatrixMul(OpNode):
  300. type = "MatrixMul"
  301. opdef = builtin.MatrixMul
  302. @register_flops(MatrixMul)
  303. def flops_matmul(opnode: MatrixMul, inputs, outputs):
  304. assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
  305. mid_shape = inputs[0].shape[1]
  306. return np.prod(outputs[0].shape) * mid_shape
  307. class BatchedMatrixMul(OpNode):
  308. type = "BatchedMatmul"
  309. opdef = builtin.BatchedMatrixMul
  310. @register_flops(BatchedMatrixMul)
  311. def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
  312. assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
  313. mid_shape = inputs[0].shape[2]
  314. return np.prod(outputs[0].shape) * mid_shape
  315. class Dot(OpNode):
  316. type = "Dot"
  317. opdef = builtin.Dot
  318. class SVD(OpNode):
  319. type = "SVD"
  320. opdef = builtin.SVD
  321. class ConvolutionForward(OpNode):
  322. type = "Convolution"
  323. opdef = builtin.Convolution
  324. class ConvolutionBackwardData(OpNode):
  325. type = "ConvTranspose"
  326. opdef = builtin.ConvolutionBackwardData
  327. class DeformableConvForward(OpNode):
  328. type = "DeformableConv"
  329. opdef = builtin.DeformableConv
  330. class GroupLocalForward(OpNode):
  331. type = "GroupLocal"
  332. opdef = builtin.GroupLocal
  333. class PoolingForward(OpNode):
  334. type = "Pooling"
  335. opdef = builtin.Pooling
  336. class AdaptivePoolingForward(OpNode):
  337. type = "AdaptivePooling"
  338. opdef = builtin.AdaptivePooling
  339. class ROIPoolingForward(OpNode):
  340. type = "ROIPooling"
  341. opdef = builtin.ROIPooling
  342. class DeformablePSROIPoolingForward(OpNode):
  343. type = "DeformablePSROIPooling"
  344. opdef = builtin.DeformablePSROIPooling
  345. class ConvBiasForward(OpNode):
  346. type = "ConvBias"
  347. opdef = builtin.ConvBias
  348. @classmethod
  349. def load(cls, opr):
  350. obj = super(ConvBiasForward, cls).load(opr)
  351. obj.params["dtype"] = opr.outputs[0].dtype
  352. return obj
  353. @register_flops(
  354. ConvolutionForward, ConvBiasForward,
  355. )
  356. def flops_conv(opnode: ConvolutionForward, inputs, outputs):
  357. param_W_shape = inputs[1].shape
  358. kh = param_W_shape[-2]
  359. kw = param_W_shape[-1]
  360. if len(param_W_shape) == 5:
  361. num_input = param_W_shape[2]
  362. else:
  363. num_input = param_W_shape[1]
  364. NCHW = np.prod(outputs[0].shape)
  365. bias = 1 if isinstance(opnode, ConvBiasForward) else 0
  366. # N x Cout x H x W x (Cin x Kw x Kh)
  367. return NCHW * (num_input * kw * kh + bias)
  368. @register_receptive_field(ConvolutionForward, ConvBiasForward)
  369. def receptive_field(opnode: ConvolutionForward, inputs, outputs):
  370. pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
  371. param_W_shape = inputs[1].shape
  372. kh = param_W_shape[-2]
  373. kw = param_W_shape[-1]
  374. rf = (
  375. kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
  376. kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
  377. )
  378. stride = (
  379. opnode.params["stride_h"] * pre_stride[0],
  380. opnode.params["stride_w"] * pre_stride[1],
  381. )
  382. opnode._rf = rf
  383. opnode._stride = stride
  384. return rf, stride
  385. class BatchConvBiasForward(OpNode):
  386. type = "BatchConvBias"
  387. opdef = builtin.BatchConvBias
  388. @classmethod
  389. def load(cls, opr):
  390. obj = super(BatchConvBiasForward, cls).load(opr)
  391. obj.params["dtype"] = opr.outputs[0].dtype
  392. return obj
  393. class BatchNormForward(OpNode):
  394. type = "BatchNorm"
  395. opdef = builtin.BatchNorm
  396. output_idx = -1
  397. class ROIAlignForward(OpNode):
  398. type = "ROIAlign"
  399. opdef = builtin.ROIAlign
  400. class WarpPerspectiveForward(OpNode):
  401. type = "WarpPerspective"
  402. opdef = builtin.WarpPerspective
  403. class WarpAffineForward(OpNode):
  404. type = "WarpAffine"
  405. opdef = builtin.WarpAffine
  406. class RemapForward(OpNode):
  407. type = "Remap"
  408. opdef = builtin.Remap
  409. class ResizeForward(OpNode):
  410. type = "Resize"
  411. opdef = builtin.Resize
  412. class IndexingOneHot(OpNode):
  413. type = "IndexingOneHot"
  414. opdef = builtin.IndexingOneHot
  415. class IndexingSetOneHot(OpNode):
  416. type = "IndexingSetOneHot"
  417. opdef = builtin.IndexingSetOneHot
  418. class Copy(OpNode):
  419. type = "Copy"
  420. opdef = builtin.Copy
  421. @classmethod
  422. def load(cls, opr):
  423. obj = super(Copy, cls).load(opr)
  424. obj.params["comp_node"] = opr.outputs[0].comp_node
  425. return obj
  426. class ArgsortForward(OpNode):
  427. type = "Argsort"
  428. opdef = builtin.Argsort
  429. class Argmax(OpNode):
  430. type = "Argmax"
  431. opdef = builtin.Argmax
  432. class Argmin(OpNode):
  433. type = "Argmin"
  434. opdef = builtin.Argmin
  435. class CondTake(OpNode):
  436. type = "CondTake"
  437. opdef = builtin.CondTake
  438. class TopK(OpNode):
  439. type = "TopK"
  440. opdef = builtin.TopK
  441. class NvOf(OpNode):
  442. type = "NvOf"
  443. opdef = builtin.NvOf
  444. class RNGOpr(OpNode):
  445. @classmethod
  446. def load(cls, opr):
  447. obj = super(RNGOpr, cls).load(opr)
  448. if len(obj.params) == 3:
  449. obj.opdef = builtin.GaussianRNG
  450. obj.type = "GaussianRNG"
  451. else:
  452. obj.opdef = builtin.UniformRNG
  453. obj.type = "UniformRNG"
  454. return obj
  455. class Linspace(OpNode):
  456. type = "Linspace"
  457. opdef = builtin.Linspace
  458. @classmethod
  459. def load(cls, opr):
  460. obj = super(Linspace, cls).load(opr)
  461. obj.params["comp_node"] = opr.outputs[0].comp_node
  462. return obj
  463. class Eye(OpNode):
  464. type = "Eye"
  465. opdef = builtin.Eye
  466. @classmethod
  467. def load(cls, opr):
  468. obj = super(Eye, cls).load(opr)
  469. obj.params["dtype"] = opr.outputs[0].dtype
  470. obj.params["comp_node"] = opr.outputs[0].comp_node
  471. return obj
  472. class GetVarShape(OpNode):
  473. type = "GetVarShape"
  474. opdef = builtin.GetVarShape
  475. class Concat(OpNode):
  476. type = "Concat"
  477. opdef = builtin.Concat
  478. @classmethod
  479. def load(cls, opr):
  480. obj = super(Concat, cls).load(opr)
  481. obj.params["comp_node"] = Device("xpux").to_c()
  482. return obj
  483. class Broadcast(OpNode):
  484. type = "Broadcast"
  485. opdef = builtin.Broadcast
  486. class Identity(OpNode):
  487. type = "Identity"
  488. opdef = builtin.Identity
  489. class NMSKeep(OpNode):
  490. type = "NMSKeep"
  491. opdef = builtin.NMSKeep
  492. # class ParamPackSplit
  493. # class ParamPackConcat
  494. class Dimshuffle(OpNode):
  495. type = "Dimshuffle"
  496. opdef = builtin.Dimshuffle
  497. @classmethod
  498. def load(cls, opr):
  499. obj = super(Dimshuffle, cls).load(opr)
  500. del obj.params["ndim"]
  501. return obj
  502. class Reshape(OpNode):
  503. type = "Reshape"
  504. opdef = builtin.Reshape
  505. class AxisAddRemove(OpNode):
  506. type = "AxisAddRemove"
  507. @classmethod
  508. def load(cls, opr):
  509. obj = cls()
  510. obj.name = opr.name
  511. obj._opr = opr
  512. params = json.loads(opr.params)
  513. desc = params["desc"]
  514. method = None
  515. axis = []
  516. for i in desc:
  517. if method is None:
  518. method = i["method"]
  519. assert method == i["method"]
  520. axis.append(i["axisnum"])
  521. obj.params = {"axis": axis}
  522. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  523. return obj
  524. class IndexingBase(OpNode):
  525. @classmethod
  526. def load(cls, opr):
  527. obj = cls()
  528. obj.name = opr.name
  529. obj._opr = opr
  530. params = json.loads(opr.params)
  531. items = [
  532. [
  533. p["axis"],
  534. bool(p["begin"]),
  535. bool(p["end"]),
  536. bool(p["step"]),
  537. bool(p["idx"]),
  538. ]
  539. for p in params
  540. ]
  541. obj.params["items"] = items
  542. return obj
  543. class Subtensor(IndexingBase):
  544. type = "Subtensor"
  545. opdef = builtin.Subtensor
  546. class SetSubtensor(IndexingBase):
  547. type = "SetSubtensor"
  548. opdef = builtin.SetSubtensor
  549. class IncrSubtensor(IndexingBase):
  550. type = "IncrSubtensor"
  551. opdef = builtin.IncrSubtensor
  552. class IndexingMultiAxisVec(IndexingBase):
  553. type = "IndexingMultiAxisVec"
  554. opdef = builtin.IndexingMultiAxisVec
  555. class IndexingSetMultiAxisVec(IndexingBase):
  556. type = "IndexingSetMultiAxisVec"
  557. opdef = builtin.IndexingSetMultiAxisVec
  558. class IndexingIncrMultiAxisVec(IndexingBase):
  559. type = "IndexingIncrMultiAxisVec"
  560. opdef = builtin.IndexingIncrMultiAxisVec
  561. class MeshIndexing(IndexingBase):
  562. type = "MeshIndexing"
  563. opdef = builtin.MeshIndexing
  564. class SetMeshIndexing(IndexingBase):
  565. type = "SetMeshIndexing"
  566. opdef = builtin.SetMeshIndexing
  567. class IncrMeshIndexing(IndexingBase):
  568. type = "IncrMeshIndexing"
  569. opdef = builtin.IncrMeshIndexing
  570. class BatchedMeshIndexing(IndexingBase):
  571. type = "BatchedMeshIndexing"
  572. opdef = builtin.BatchedMeshIndexing
  573. class BatchedSetMeshIndexing(IndexingBase):
  574. type = "BatchedSetMeshIndexing"
  575. opdef = builtin.BatchedSetMeshIndexing
  576. class BatchedIncrMeshIndexing(IndexingBase):
  577. type = "BatchedIncrMeshIndexing"
  578. opdef = builtin.BatchedIncrMeshIndexing
  579. # class CollectiveComm
  580. # class RemoteSend
  581. # class RemoteRecv
  582. # class TQT
  583. # class FakeQuant
  584. # class InplaceAdd
  585. class AssertEqual(OpNode):
  586. type = "AssertEqual"
  587. opdef = builtin.AssertEqual
  588. class CvtColorForward(OpNode):
  589. type = "CvtColor"
  590. opdef = builtin.CvtColor