
network_node.py

# -*- coding: utf-8 -*-
import json
import sys
from typing import Sequence

import numpy as np

from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import apply, set_py_varnode_type
from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
from .module_stats import (
    preprocess_receptive_field,
    register_flops,
    register_receptive_field,
)


class NetworkNode:
    pass


class VarNode(NetworkNode, Tensor):
    _users = None
    _owner = None
    _name = None
    _id = None

    def __new__(cls, var, *, owner_opr=None, name=None):
        obj = Tensor.__new__(cls, var)
        return obj

    def __init__(self, var, *, owner_opr=None, name=None):
        self._owner = owner_opr
        self.name = name

    @classmethod
    def load(cls, sym_var, owner_opr):
        obj = cls(sym_var)
        obj.var = sym_var  # mgb varnode
        obj.name = sym_var.name
        obj.owner = owner_opr
        return obj

    @property
    def users(self):
        if self._users is None:
            self._users = []
        return self._users

    @property
    def owner(self):
        return self._owner

    @owner.setter
    def owner(self, owner):
        self._owner = owner

    @property
    def id(self):
        if self._id is None:
            self._id = id(self)
        return self._id

    @property
    def var(self):
        return super().var()

    @var.setter
    def var(self, var):
        self._reset(var)

    def _reset(self, other):
        if not isinstance(other, Tensor):
            other = VarNode(other)
        super()._reset(other)
        self.owner = None

    def _reset_var(self, var):
        origin_owner = self.owner
        self.var = var
        self.var.name = self.name
        self.owner = origin_owner

    @property
    def graph(self):
        return super().graph()

    def _get_var_shape(self, axis=None):
        opdef = (
            builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
        )
        return apply(opdef, self)[0]

    @property
    def partial_shape(self):
        r"""Return the tuple type inferred shape of VarNode"""
        return tuple(self._get_var_shape().numpy())

    def shapeof(self, axis):
        r"""Return the symbolic shape of axis"""
        return self._get_var_shape(axis=axis) if self.var else None

    @property
    def _tuple_shape(self):
        return self.partial_shape

    @property
    def shape(self):
        r"""Return the symbolic shape if using set_symbolic_shape(True)
        else inferred shape
        """
        rst = None
        if self.var:
            try:
                rst = self.var.shape
            except:
                rst = None
        if not use_symbolic_shape():
            return rst
        return self._get_var_shape() if self.var else None

    def __bool__(self):
        return False

    __index__ = None
    __int__ = None
    __float__ = None
    __complex__ = None

    __repr__ = lambda self: "VarNode:" + self.name

    def __hash__(self):
        return id(self)

    def set_owner_opr(self, owner_opr):
        self.owner = owner_opr
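
# Reading aid (comment added, not in the original source): VarNode wraps an
# mgb varnode together with its owning OpNode. A hypothetical illustration of
# the shape-related accessors defined above:
#
#   v = VarNode.load(sym_var, None)   # sym_var: an mgb varnode, owner unknown
#   v.partial_shape                   # concrete tuple, e.g. (1, 3, 224, 224)
#   v.shape       # same tuple, or a symbolic value under use_symbolic_shape()
#   v.shapeof(1)  # symbolic shape of axis 1, via builtin.GetVarShape(axis=1)
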
class OpNode(NetworkNode):
    opdef = None
    type = None

    def __init__(self):
        self.inputs = []
        self.outputs = []
        self.params = {}
        self._opr = None  # mgb opnode

    @property
    def id(self):
        return id(self)

    @property
    def priority(self):
        if self._opr is not None:
            return (self._opr.priority, self._opr.id)
        return (0, 0)

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.params = json.loads(opr.params)
        obj.name = opr.name
        obj._opr = opr
        return obj

    def compile(self):
        if (
            self._opr is None
            or len(self._opr.inputs) != len(self.inputs)
            or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
        ):
            op = self.opdef(**self.params)
            args = [i.var for i in self.inputs]
            outputs = rt.invoke_op(op, args)
            assert len(outputs) == len(self.outputs)
            self._opr = outputs[0].owner
            for i in range(len(self.outputs)):
                self.outputs[i]._reset_var(outputs[i])
                assert self.outputs[i].owner is self

    def add_inp_var(self, x):
        self.inputs.append(x)

    def add_out_var(self, x):
        self.outputs.append(x)

    def __repr__(self):
        return "%s{%s}" % (self.name, self.type)


def str_to_mge_class(classname):
    # TODO: use megbrain C++ RTTI to replace type string
    if classname == "RNGOpr<MegDNNOpr>":
        classname = "RNGOpr"
    oprcls = getattr(sys.modules[__name__], classname, None)
    return oprcls if oprcls else ReadOnlyOpNode
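
# Resolution sketch for str_to_mge_class (comment only, not in the original
# source): the megbrain type string is looked up as a class name in this
# module, so known operator types map to their wrappers and unknown ones fall
# back to ReadOnlyOpNode:
#
#   str_to_mge_class("Elemwise")           # -> Elemwise
#   str_to_mge_class("RNGOpr<MegDNNOpr>")  # -> RNGOpr (special-cased above)
#   str_to_mge_class("SomeUnknownOpr")     # -> ReadOnlyOpNode
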
class Host2DeviceCopy(OpNode):
    type = "Host2DeviceCopy"

    def __init__(self, shape=None, dtype=None, name=None, device=None):
        super().__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self.device = Device(device).to_c() if device else Device("xpux").to_c()
        self.outputs = []

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        assert len(opr.outputs) == 1, "wrong number of outputs"
        self.shape = opr.outputs[0].shape
        self.dtype = opr.outputs[0].dtype
        self.name = opr.outputs[0].name
        self.device = opr.outputs[0].comp_node
        self._opr = opr
        return self

    def compile(self, graph):
        if (
            self._opr is None
            or self._opr.graph != graph
            or self._opr.outputs[0].comp_node != self.device
            or self._opr.outputs[0].shape != self.shape
            or self._opr.outputs[0].dtype != self.dtype
        ):
            outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
            self._opr = outputs.owner
            if len(self.outputs) == 0:
                self.outputs.append(VarNode(outputs, owner_opr=self, name=self.name))
            else:
                self.outputs[0]._reset_var(outputs)
            assert self.outputs[0].owner is self
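
# Usage sketch for Host2DeviceCopy (an assumption about typical use; `graph`
# below is a placeholder for an mgb computing graph, not defined here): the
# node only describes a graph input; the underlying H2D operator is created
# lazily when compile(graph) runs.
#
#   inp = Host2DeviceCopy(shape=(1, 3, 224, 224), dtype="float32", name="data")
#   # inp.compile(graph)   # builds the rt.make_h2d(...) output on first compile
#   # inp.outputs[0]       # VarNode feeding downstream OpNodes
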
class ConstOpBase(OpNode):
    type = "ConstOpBase"

    def __init__(self, data=None, name=None, device=None, graph=None):
        assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
        super().__init__()
        self.name = name
        self.outputs = []
        self.graph = graph
        if data is not None:
            self.set_value(data, device)

    @property
    def device(self):
        return self._opr.outputs[0].comp_node if self._opr else None

    @device.setter
    def device(self, device):
        self.set_value(self.numpy(), device)

    @property
    def shape(self):
        return self.outputs[0].shape

    @property
    def dtype(self):
        return self._opr.outputs[0].dtype if self._opr else None

    def numpy(self):
        return self.outputs[0].numpy()

    def set_value(self, data, device=None):
        assert self.graph is not None
        cn = device if device else self.device
        assert isinstance(data, (int, float, Sequence, np.ndarray))
        if not isinstance(data, np.ndarray):
            data = np.array(data)
        if data.dtype == np.float64:
            data = data.astype(np.float32)
        elif data.dtype == np.int64:
            data = data.astype(np.int32)
        varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(varnode, owner_opr=self, name=self.name))
        else:
            self.outputs[0]._reset_var(varnode)
        self._opr = varnode.owner

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        self._opr = opr
        self.name = opr.outputs[0].name
        self.graph = opr.graph
        return self

    def compile(self, graph):
        assert self.outputs[0].var is self._opr.outputs[0]
        assert self.outputs[0].owner is self
        if self.graph != graph:
            self.graph = graph
            self.set_value(self.numpy())
        if self.name is not None:
            self.outputs[0].var.name = self.name


class ImmutableTensor(ConstOpBase):
    type = "ImmutableTensor"
    rt_fun = rt.make_const


class SharedDeviceTensor(ConstOpBase):
    type = "SharedDeviceTensor"
    rt_fun = rt.make_shared
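
# Value-setting sketch for the ConstOpBase subclasses above (comment only;
# `graph` is a placeholder computing graph): set_value narrows 64-bit numpy
# dtypes before creating the constant var.
#
#   w = ImmutableTensor(data=3.14, name="w", graph=graph)          # -> float32
#   s = SharedDeviceTensor(data=[1, 2, 3], name="s", graph=graph)  # -> int32
#   w.numpy(), w.dtype, w.device   # read back value, dtype and comp_node
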
class ReadOnlyOpNode(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(ReadOnlyOpNode, cls).load(opr)
        obj.type = opr.type
        return obj

    def compile(self):
        assert self._opr is not None
        assert len(self.inputs) == len(self._opr.inputs)
        assert len(self.outputs) == len(self._opr.outputs)
        repl_dict = {}
        for ind, i in enumerate(self.inputs):
            if i.var != self._opr.inputs[ind]:
                repl_dict[self._opr.inputs[ind]] = i.var
        if bool(repl_dict):
            out_vars = replace_vars(self._opr.outputs, repl_dict)
            for ind, o in enumerate(self.outputs):
                o._reset_var(out_vars[ind])


class Elemwise(OpNode):
    type = "Elemwise"
    opdef = builtin.Elemwise

    def __repr__(self):
        return "%s{Elemwise:%s}" % (self.name, self.params["mode"])


class ElemwiseMultiType(OpNode):
    type = "ElemwiseMultiType"
    opdef = builtin.ElemwiseMultiType

    def __repr__(self):
        return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])

    @classmethod
    def load(cls, opr):
        obj = super(ElemwiseMultiType, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
    return np.prod(outputs[0].shape)
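
# Worked example (comment only): an elementwise op with output shape
# (8, 64, 32, 32) is counted as 8 * 64 * 32 * 32 = 524288 FLOPs, i.e. one
# operation per output element.
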
class Reduce(OpNode):
    type = "Reduce"
    opdef = builtin.Reduce


class TypeCvt(OpNode):
    type = "TypeCvt"
    opdef = builtin.TypeCvt

    @classmethod
    def load(cls, opr):
        obj = super(TypeCvt, cls).load(opr)
        t_dtype = opr.outputs[0].dtype
        obj.params["dtype"] = t_dtype
        return obj


class MatrixInverse(OpNode):
    type = "MatrixInverse"
    opdef = builtin.MatrixInverse


class MatrixMul(OpNode):
    type = "MatrixMul"
    opdef = builtin.MatrixMul


@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
    assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
    mid_shape = inputs[0].shape[1]
    return np.prod(outputs[0].shape) * mid_shape
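
# Worked example (comment only): multiplying a (128, 256) matrix by a
# (256, 512) matrix gives a (128, 512) output, so the count is
# 128 * 512 * 256 = 16777216 multiply-accumulates.
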
class BatchedMatrixMul(OpNode):
    type = "BatchedMatmul"
    opdef = builtin.BatchedMatrixMul


@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
    assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
    mid_shape = inputs[0].shape[2]
    return np.prod(outputs[0].shape) * mid_shape


class Dot(OpNode):
    type = "Dot"
    opdef = builtin.Dot


class SVD(OpNode):
    type = "SVD"
    opdef = builtin.SVD


class ConvolutionForward(OpNode):
    type = "Convolution"
    opdef = builtin.Convolution


class ConvolutionBackwardData(OpNode):
    type = "ConvTranspose"
    opdef = builtin.ConvolutionBackwardData


class DeformableConvForward(OpNode):
    type = "DeformableConv"
    opdef = builtin.DeformableConv


class GroupLocalForward(OpNode):
    type = "GroupLocal"
    opdef = builtin.GroupLocal


class PoolingForward(OpNode):
    type = "Pooling"
    opdef = builtin.Pooling


class AdaptivePoolingForward(OpNode):
    type = "AdaptivePooling"
    opdef = builtin.AdaptivePooling


class ROIPoolingForward(OpNode):
    type = "ROIPooling"
    opdef = builtin.ROIPooling


class DeformablePSROIPoolingForward(OpNode):
    type = "DeformablePSROIPooling"
    opdef = builtin.DeformablePSROIPooling


class ConvBiasForward(OpNode):
    type = "ConvBias"
    opdef = builtin.ConvBias

    @classmethod
    def load(cls, opr):
        obj = super(ConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


@register_flops(
    ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
    param_W_shape = inputs[1].shape
    kh = param_W_shape[-2]
    kw = param_W_shape[-1]
    if len(param_W_shape) == 5:
        num_input = param_W_shape[2]
    else:
        num_input = param_W_shape[1]
    NCHW = np.prod(outputs[0].shape)
    bias = 1 if isinstance(opnode, ConvBiasForward) else 0
    # N x Cout x H x W x (Cin x Kw x Kh)
    return NCHW * (float(num_input * kw * kh) + bias)
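
# Worked example (comment only): for a dense convolution with weight shape
# (Cout, Cin, Kh, Kw) = (64, 32, 3, 3) and output shape (1, 64, 56, 56),
# num_input = 32 and
#   FLOPs = 1 * 64 * 56 * 56 * (32 * 3 * 3) = 200704 * 288 = 57802752.
# A 5-dim weight shape denotes a grouped convolution (num_input taken from
# axis 2), and ConvBias adds one extra operation per output element.
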
@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
    pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
    param_W_shape = inputs[1].shape
    kh = param_W_shape[-2]
    kw = param_W_shape[-1]
    rf = (
        kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
        kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
    )
    stride = (
        opnode.params["stride_h"] * pre_stride[0],
        opnode.params["stride_w"] * pre_stride[1],
    )
    opnode._rf = rf
    opnode._stride = stride
    return rf, stride
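
# Worked example (comment only): if the preceding layers give pre_rf = (3, 3)
# and pre_stride = (2, 2), a 3x3 convolution with stride_h = stride_w = 1
# yields
#   rf     = (3 * 2 + 3 - 2, 3 * 2 + 3 - 2) = (7, 7)
#   stride = (1 * 2, 1 * 2) = (2, 2)
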
class BatchConvBiasForward(OpNode):
    type = "BatchConvBias"
    opdef = builtin.BatchConvBias

    @classmethod
    def load(cls, opr):
        obj = super(BatchConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


class BatchNormForward(OpNode):
    type = "BatchNorm"
    opdef = builtin.BatchNorm
    output_idx = -1


class ROIAlignForward(OpNode):
    type = "ROIAlign"
    opdef = builtin.ROIAlign


class WarpPerspectiveForward(OpNode):
    type = "WarpPerspective"
    opdef = builtin.WarpPerspective


class WarpAffineForward(OpNode):
    type = "WarpAffine"
    opdef = builtin.WarpAffine


class RemapForward(OpNode):
    type = "Remap"
    opdef = builtin.Remap


class ResizeForward(OpNode):
    type = "Resize"
    opdef = builtin.Resize


class IndexingOneHot(OpNode):
    type = "IndexingOneHot"
    opdef = builtin.IndexingOneHot


class IndexingSetOneHot(OpNode):
    type = "IndexingSetOneHot"
    opdef = builtin.IndexingSetOneHot


class Copy(OpNode):
    type = "Copy"
    opdef = builtin.Copy

    @classmethod
    def load(cls, opr):
        obj = super(Copy, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class ArgsortForward(OpNode):
    type = "Argsort"
    opdef = builtin.Argsort


class Argmax(OpNode):
    type = "Argmax"
    opdef = builtin.Argmax


class Argmin(OpNode):
    type = "Argmin"
    opdef = builtin.Argmin


class CondTake(OpNode):
    type = "CondTake"
    opdef = builtin.CondTake


class TopK(OpNode):
    type = "TopK"
    opdef = builtin.TopK


class NvOf(OpNode):
    type = "NvOf"
    opdef = builtin.NvOf


class RNGOpr(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(RNGOpr, cls).load(opr)
        if len(obj.params) == 3:
            obj.opdef = builtin.GaussianRNG
            obj.type = "GaussianRNG"
        else:
            obj.opdef = builtin.UniformRNG
            obj.type = "UniformRNG"
        return obj


class Linspace(OpNode):
    type = "Linspace"
    opdef = builtin.Linspace

    @classmethod
    def load(cls, opr):
        obj = super(Linspace, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class Eye(OpNode):
    type = "Eye"
    opdef = builtin.Eye

    @classmethod
    def load(cls, opr):
        obj = super(Eye, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class GetVarShape(OpNode):
    type = "GetVarShape"
    opdef = builtin.GetVarShape


class Concat(OpNode):
    type = "Concat"
    opdef = builtin.Concat

    @classmethod
    def load(cls, opr):
        obj = super(Concat, cls).load(opr)
        obj.params["comp_node"] = Device("xpux").to_c()
        return obj


class Broadcast(OpNode):
    type = "Broadcast"
    opdef = builtin.Broadcast


class Identity(OpNode):
    type = "Identity"
    opdef = builtin.Identity


class NMSKeep(OpNode):
    type = "NMSKeep"
    opdef = builtin.NMSKeep


# class ParamPackSplit
# class ParamPackConcat


class Dimshuffle(OpNode):
    type = "Dimshuffle"
    opdef = builtin.Dimshuffle

    @classmethod
    def load(cls, opr):
        obj = super(Dimshuffle, cls).load(opr)
        del obj.params["ndim"]
        return obj


class Reshape(OpNode):
    type = "Reshape"
    opdef = builtin.Reshape


class AxisAddRemove(OpNode):
    type = "AxisAddRemove"

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        desc = params["desc"]
        method = None
        axis = []
        for i in desc:
            if method is None:
                method = i["method"]
            assert method == i["method"]
            axis.append(i["axisnum"])
        obj.params = {"axis": axis}
        obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
        return obj
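
# Parsing sketch for AxisAddRemove.load (the serialized payload below is an
# illustrative assumption, shaped after the keys read above): a descriptor
# such as
#   {"desc": [{"method": 0, "axisnum": 1}]}
# yields params = {"axis": [1]} with opdef = builtin.AddAxis; method == 1
# entries map to builtin.RemoveAxis, and mixing methods trips the assertion.
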
class IndexingBase(OpNode):
    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        items = [
            [
                p["axis"],
                bool(p["begin"]),
                bool(p["end"]),
                bool(p["step"]),
                bool(p["idx"]),
            ]
            for p in params
        ]
        obj.params["items"] = items
        return obj
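
# Parsing sketch for IndexingBase.load (illustrative payload): each serialized
# slice descriptor carries an axis index plus flags for begin/end/step/idx,
# e.g.
#   [{"axis": 0, "begin": 1, "end": 1, "step": 0, "idx": 0}]
# becomes params["items"] = [[0, True, True, False, False]], the form passed
# as the `items` argument when compile() instantiates the opdef.
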
class Subtensor(IndexingBase):
    type = "Subtensor"
    opdef = builtin.Subtensor


class SetSubtensor(IndexingBase):
    type = "SetSubtensor"
    opdef = builtin.SetSubtensor


class IncrSubtensor(IndexingBase):
    type = "IncrSubtensor"
    opdef = builtin.IncrSubtensor


class IndexingMultiAxisVec(IndexingBase):
    type = "IndexingMultiAxisVec"
    opdef = builtin.IndexingMultiAxisVec


class IndexingSetMultiAxisVec(IndexingBase):
    type = "IndexingSetMultiAxisVec"
    opdef = builtin.IndexingSetMultiAxisVec


class IndexingIncrMultiAxisVec(IndexingBase):
    type = "IndexingIncrMultiAxisVec"
    opdef = builtin.IndexingIncrMultiAxisVec


class MeshIndexing(IndexingBase):
    type = "MeshIndexing"
    opdef = builtin.MeshIndexing


class SetMeshIndexing(IndexingBase):
    type = "SetMeshIndexing"
    opdef = builtin.SetMeshIndexing


class IncrMeshIndexing(IndexingBase):
    type = "IncrMeshIndexing"
    opdef = builtin.IncrMeshIndexing


class BatchedMeshIndexing(IndexingBase):
    type = "BatchedMeshIndexing"
    opdef = builtin.BatchedMeshIndexing


class BatchedSetMeshIndexing(IndexingBase):
    type = "BatchedSetMeshIndexing"
    opdef = builtin.BatchedSetMeshIndexing


class BatchedIncrMeshIndexing(IndexingBase):
    type = "BatchedIncrMeshIndexing"
    opdef = builtin.BatchedIncrMeshIndexing


# class CollectiveComm
# class RemoteSend
# class RemoteRecv
# class TQT
# class FakeQuant
# class InplaceAdd


class AssertEqual(OpNode):
    type = "AssertEqual"
    opdef = builtin.AssertEqual


class CvtColorForward(OpNode):
    type = "CvtColor"
    opdef = builtin.CvtColor


set_py_varnode_type(VarNode)
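
# Closing note (comment added, not in the original source): the call above
# registers VarNode as the Python wrapper type for graph variables in the
# imperative runtime, so values produced via apply(...) in this module (for
# example in VarNode._get_var_shape) come back as VarNode instances. A hedged
# loading sketch, assuming `opr` is an mgb operator obtained from a loaded
# graph:
#
#   # node = str_to_mge_class(opr.type).load(opr)
#   # node.compile()   # re-emits the operator with the node's current inputs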