You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

layers.py 18 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
  1. from abc import abstractmethod
  2. from collections.abc import Iterable
  3. import torch
  4. from torch import nn
  5. from torch.nn import functional
  6. from utils import Constant
  7. class AvgPool(nn.Module):
  8. """
  9. AvgPool Module.
  10. """
  11. def __init__(self):
  12. super().__init__()
  13. @abstractmethod
  14. def forward(self, input_tensor):
  15. pass
  16. class GlobalAvgPool1d(AvgPool):
  17. """
  18. GlobalAvgPool1d Module.
  19. """
  20. def forward(self, input_tensor):
  21. return functional.avg_pool1d(input_tensor, input_tensor.size()[2:]).view(
  22. input_tensor.size()[:2]
  23. )
  24. class GlobalAvgPool2d(AvgPool):
  25. """
  26. GlobalAvgPool2d Module.
  27. """
  28. def forward(self, input_tensor):
  29. return functional.avg_pool2d(input_tensor, input_tensor.size()[2:]).view(
  30. input_tensor.size()[:2]
  31. )
  32. class GlobalAvgPool3d(AvgPool):
  33. """
  34. GlobalAvgPool3d Module.
  35. """
  36. def forward(self, input_tensor):
  37. return functional.avg_pool3d(input_tensor, input_tensor.size()[2:]).view(
  38. input_tensor.size()[:2]
  39. )
  40. class StubLayer:
  41. """
  42. StubLayer Module. Base Module.
  43. """
  44. def __init__(self, input_node=None, output_node=None):
  45. self.input = input_node
  46. self.output = output_node
  47. self.weights = None
  48. def build(self, shape):
  49. """
  50. build shape.
  51. """
  52. def set_weights(self, weights):
  53. """
  54. set weights.
  55. """
  56. self.weights = weights
  57. def import_weights(self, torch_layer):
  58. """
  59. import weights.
  60. """
  61. def export_weights(self, torch_layer):
  62. """
  63. export weights.
  64. """
  65. def get_weights(self):
  66. """
  67. get weights.
  68. """
  69. return self.weights
  70. def size(self):
  71. """
  72. size().
  73. """
  74. return 0
  75. @property
  76. def output_shape(self):
  77. """
  78. output shape.
  79. """
  80. return self.input.shape
  81. def to_real_layer(self):
  82. """
  83. to real layer.
  84. """
  85. def __str__(self):
  86. """
  87. str() function to print.
  88. """
  89. return type(self).__name__[4:]
  90. class StubWeightBiasLayer(StubLayer):
  91. """
  92. StubWeightBiasLayer Module to set the bias.
  93. """
  94. def import_weights(self, torch_layer):
  95. self.set_weights(
  96. (torch_layer.weight.data.cpu().numpy(),
  97. torch_layer.bias.data.cpu().numpy())
  98. )
  99. def export_weights(self, torch_layer):
  100. torch_layer.weight.data = torch.Tensor(self.weights[0])
  101. torch_layer.bias.data = torch.Tensor(self.weights[1])
  102. class StubBatchNormalization(StubWeightBiasLayer):
  103. """
  104. StubBatchNormalization Module. Batch Norm.
  105. """
  106. def __init__(self, num_features, input_node=None, output_node=None):
  107. super().__init__(input_node, output_node)
  108. self.num_features = num_features
  109. def import_weights(self, torch_layer):
  110. self.set_weights(
  111. (
  112. torch_layer.weight.data.cpu().numpy(),
  113. torch_layer.bias.data.cpu().numpy(),
  114. torch_layer.running_mean.cpu().numpy(),
  115. torch_layer.running_var.cpu().numpy(),
  116. )
  117. )
  118. def export_weights(self, torch_layer):
  119. torch_layer.weight.data = torch.Tensor(self.weights[0])
  120. torch_layer.bias.data = torch.Tensor(self.weights[1])
  121. torch_layer.running_mean = torch.Tensor(self.weights[2])
  122. torch_layer.running_var = torch.Tensor(self.weights[3])
  123. def size(self):
  124. return self.num_features * 4
  125. @abstractmethod
  126. def to_real_layer(self):
  127. pass
  128. class StubBatchNormalization1d(StubBatchNormalization):
  129. """
  130. StubBatchNormalization1d Module.
  131. """
  132. def to_real_layer(self):
  133. return torch.nn.BatchNorm1d(self.num_features)
  134. class StubBatchNormalization2d(StubBatchNormalization):
  135. """
  136. StubBatchNormalization2d Module.
  137. """
  138. def to_real_layer(self):
  139. return torch.nn.BatchNorm2d(self.num_features)
  140. class StubBatchNormalization3d(StubBatchNormalization):
  141. """
  142. StubBatchNormalization3d Module.
  143. """
  144. def to_real_layer(self):
  145. return torch.nn.BatchNorm3d(self.num_features)
  146. class StubDense(StubWeightBiasLayer):
  147. """
  148. StubDense Module. Linear.
  149. """
  150. def __init__(self, input_units, units, input_node=None, output_node=None):
  151. super().__init__(input_node, output_node)
  152. self.input_units = input_units
  153. self.units = units
  154. @property
  155. def output_shape(self):
  156. return (self.units,)
  157. def size(self):
  158. return self.input_units * self.units + self.units
  159. def to_real_layer(self):
  160. return torch.nn.Linear(self.input_units, self.units)
  161. class StubConv(StubWeightBiasLayer):
  162. """
  163. StubConv Module. Conv.
  164. """
  165. def __init__(self, input_channel, filters, kernel_size,
  166. stride=1, input_node=None, output_node=None):
  167. super().__init__(input_node, output_node)
  168. self.input_channel = input_channel
  169. self.filters = filters
  170. self.kernel_size = kernel_size
  171. self.stride = stride
  172. self.padding = int(self.kernel_size / 2)
  173. @property
  174. def output_shape(self):
  175. ret = list(self.input.shape[:-1])
  176. for index, dim in enumerate(ret):
  177. ret[index] = (
  178. int((dim + 2 * self.padding - self.kernel_size) / self.stride) + 1
  179. )
  180. ret = ret + [self.filters]
  181. return tuple(ret)
  182. def size(self):
  183. return (self.input_channel * self.kernel_size *
  184. self.kernel_size + 1) * self.filters
  185. @abstractmethod
  186. def to_real_layer(self):
  187. pass
  188. def __str__(self):
  189. return (
  190. super().__str__()
  191. + "("
  192. + ", ".join(
  193. str(item)
  194. for item in [
  195. self.input_channel,
  196. self.filters,
  197. self.kernel_size,
  198. self.stride,
  199. ]
  200. )
  201. + ")"
  202. )
  203. class StubConv1d(StubConv):
  204. """
  205. StubConv1d Module.
  206. """
  207. def to_real_layer(self):
  208. return torch.nn.Conv1d(
  209. self.input_channel,
  210. self.filters,
  211. self.kernel_size,
  212. stride=self.stride,
  213. padding=self.padding,
  214. )
  215. class StubConv2d(StubConv):
  216. """
  217. StubConv2d Module.
  218. """
  219. def to_real_layer(self):
  220. return torch.nn.Conv2d(
  221. self.input_channel,
  222. self.filters,
  223. self.kernel_size,
  224. stride=self.stride,
  225. padding=self.padding,
  226. )
  227. class StubConv3d(StubConv):
  228. """
  229. StubConv3d Module.
  230. """
  231. def to_real_layer(self):
  232. return torch.nn.Conv3d(
  233. self.input_channel,
  234. self.filters,
  235. self.kernel_size,
  236. stride=self.stride,
  237. padding=self.padding,
  238. )
  239. class StubAggregateLayer(StubLayer):
  240. """
  241. StubAggregateLayer Module.
  242. """
  243. def __init__(self, input_nodes=None, output_node=None):
  244. if input_nodes is None:
  245. input_nodes = []
  246. super().__init__(input_nodes, output_node)
  247. class StubConcatenate(StubAggregateLayer):
  248. """StubConcatenate Module.
  249. """
  250. @property
  251. def output_shape(self):
  252. ret = 0
  253. for current_input in self.input:
  254. ret += current_input.shape[-1]
  255. ret = self.input[0].shape[:-1] + (ret,)
  256. return ret
  257. def to_real_layer(self):
  258. return TorchConcatenate()
  259. class StubAdd(StubAggregateLayer):
  260. """
  261. StubAdd Module.
  262. """
  263. @property
  264. def output_shape(self):
  265. return self.input[0].shape
  266. def to_real_layer(self):
  267. return TorchAdd()
  268. class StubFlatten(StubLayer):
  269. """
  270. StubFlatten Module.
  271. """
  272. @property
  273. def output_shape(self):
  274. ret = 1
  275. for dim in self.input.shape:
  276. ret *= dim
  277. return (ret,)
  278. def to_real_layer(self):
  279. return TorchFlatten()
  280. class StubReLU(StubLayer):
  281. """
  282. StubReLU Module.
  283. """
  284. def to_real_layer(self):
  285. return torch.nn.ReLU()
  286. class StubSoftmax(StubLayer):
  287. """
  288. StubSoftmax Module.
  289. """
  290. def to_real_layer(self):
  291. return torch.nn.LogSoftmax(dim=1)
  292. class StubDropout(StubLayer):
  293. """
  294. StubDropout Module.
  295. """
  296. def __init__(self, rate, input_node=None, output_node=None):
  297. super().__init__(input_node, output_node)
  298. self.rate = rate
  299. @abstractmethod
  300. def to_real_layer(self):
  301. pass
  302. class StubDropout1d(StubDropout):
  303. """
  304. StubDropout1d Module.
  305. """
  306. def to_real_layer(self):
  307. return torch.nn.Dropout(self.rate)
  308. class StubDropout2d(StubDropout):
  309. """
  310. StubDropout2d Module.
  311. """
  312. def to_real_layer(self):
  313. return torch.nn.Dropout2d(self.rate)
  314. class StubDropout3d(StubDropout):
  315. """
  316. StubDropout3d Module.
  317. """
  318. def to_real_layer(self):
  319. return torch.nn.Dropout3d(self.rate)
  320. class StubInput(StubLayer):
  321. """
  322. StubInput Module.
  323. """
  324. def __init__(self, input_node=None, output_node=None):
  325. super().__init__(input_node, output_node)
  326. class StubPooling(StubLayer):
  327. """
  328. StubPooling Module.
  329. """
  330. def __init__(self,
  331. kernel_size=None,
  332. stride=None,
  333. padding=0,
  334. input_node=None,
  335. output_node=None):
  336. super().__init__(input_node, output_node)
  337. self.kernel_size = (
  338. kernel_size if kernel_size is not None else Constant.POOLING_KERNEL_SIZE
  339. )
  340. self.stride = stride if stride is not None else self.kernel_size
  341. self.padding = padding
  342. @property
  343. def output_shape(self):
  344. ret = tuple()
  345. for dim in self.input.shape[:-1]:
  346. ret = ret + (max(int((dim + 2 * self.padding) / self.kernel_size), 1),)
  347. ret = ret + (self.input.shape[-1],)
  348. return ret
  349. @abstractmethod
  350. def to_real_layer(self):
  351. pass
  352. class StubPooling1d(StubPooling):
  353. """
  354. StubPooling1d Module.
  355. """
  356. def to_real_layer(self):
  357. return torch.nn.MaxPool1d(self.kernel_size, stride=self.stride)
  358. class StubPooling2d(StubPooling):
  359. """
  360. StubPooling2d Module.
  361. """
  362. def to_real_layer(self):
  363. return torch.nn.MaxPool2d(self.kernel_size, stride=self.stride)
  364. class StubPooling3d(StubPooling):
  365. """
  366. StubPooling3d Module.
  367. """
  368. def to_real_layer(self):
  369. return torch.nn.MaxPool3d(self.kernel_size, stride=self.stride)
  370. class StubGlobalPooling(StubLayer):
  371. """
  372. StubGlobalPooling Module.
  373. """
  374. def __init__(self, input_node=None, output_node=None):
  375. super().__init__(input_node, output_node)
  376. @property
  377. def output_shape(self):
  378. return (self.input.shape[-1],)
  379. @abstractmethod
  380. def to_real_layer(self):
  381. pass
  382. class StubGlobalPooling1d(StubGlobalPooling):
  383. """
  384. StubGlobalPooling1d Module.
  385. """
  386. def to_real_layer(self):
  387. return GlobalAvgPool1d()
  388. class StubGlobalPooling2d(StubGlobalPooling):
  389. """
  390. StubGlobalPooling2d Module.
  391. """
  392. def to_real_layer(self):
  393. return GlobalAvgPool2d()
  394. class StubGlobalPooling3d(StubGlobalPooling):
  395. """
  396. StubGlobalPooling3d Module.
  397. """
  398. def to_real_layer(self):
  399. return GlobalAvgPool3d()
  400. class TorchConcatenate(nn.Module):
  401. """
  402. TorchConcatenate Module.
  403. """
  404. def forward(self, input_list):
  405. return torch.cat(input_list, dim=1)
  406. class TorchAdd(nn.Module):
  407. """
  408. TorchAdd Module.
  409. """
  410. def forward(self, input_list):
  411. return input_list[0] + input_list[1]
  412. class TorchFlatten(nn.Module):
  413. """
  414. TorchFlatten Module.
  415. """
  416. def forward(self, input_tensor):
  417. return input_tensor.view(input_tensor.size(0), -1)
  418. def is_layer(layer, layer_type):
  419. """
  420. Judge the layer type.
  421. Returns
  422. -------
  423. bool
  424. boolean -- True or False
  425. """
  426. if layer_type == "Input":
  427. return isinstance(layer, StubInput)
  428. elif layer_type == "Conv":
  429. return isinstance(layer, StubConv)
  430. elif layer_type == "Dense":
  431. return isinstance(layer, (StubDense,))
  432. elif layer_type == "BatchNormalization":
  433. return isinstance(layer, (StubBatchNormalization,))
  434. elif layer_type == "Concatenate":
  435. return isinstance(layer, (StubConcatenate,))
  436. elif layer_type == "Add":
  437. return isinstance(layer, (StubAdd,))
  438. elif layer_type == "Pooling":
  439. return isinstance(layer, StubPooling)
  440. elif layer_type == "Dropout":
  441. return isinstance(layer, (StubDropout,))
  442. elif layer_type == "Softmax":
  443. return isinstance(layer, (StubSoftmax,))
  444. elif layer_type == "ReLU":
  445. return isinstance(layer, (StubReLU,))
  446. elif layer_type == "Flatten":
  447. return isinstance(layer, (StubFlatten,))
  448. elif layer_type == "GlobalAveragePooling":
  449. return isinstance(layer, StubGlobalPooling)
  450. return None # note: this is not written by original author, feel free to modify if you think it's incorrect
  451. def layer_description_extractor(layer, node_to_id):
  452. """
  453. Get layer description.
  454. """
  455. layer_input = layer.input
  456. layer_output = layer.output
  457. if layer_input is not None:
  458. if isinstance(layer_input, Iterable):
  459. layer_input = list(map(lambda x: node_to_id[x], layer_input))
  460. else:
  461. layer_input = node_to_id[layer_input]
  462. if layer_output is not None:
  463. layer_output = node_to_id[layer_output]
  464. if isinstance(layer, StubConv):
  465. return (
  466. type(layer).__name__,
  467. layer_input,
  468. layer_output,
  469. layer.input_channel,
  470. layer.filters,
  471. layer.kernel_size,
  472. layer.stride,
  473. layer.padding,
  474. )
  475. elif isinstance(layer, (StubDense,)):
  476. return [
  477. type(layer).__name__,
  478. layer_input,
  479. layer_output,
  480. layer.input_units,
  481. layer.units,
  482. ]
  483. elif isinstance(layer, (StubBatchNormalization,)):
  484. return (type(layer).__name__, layer_input,
  485. layer_output, layer.num_features)
  486. elif isinstance(layer, (StubDropout,)):
  487. return (type(layer).__name__, layer_input, layer_output, layer.rate)
  488. elif isinstance(layer, StubPooling):
  489. return (
  490. type(layer).__name__,
  491. layer_input,
  492. layer_output,
  493. layer.kernel_size,
  494. layer.stride,
  495. layer.padding,
  496. )
  497. else:
  498. return (type(layer).__name__, layer_input, layer_output)
  499. def layer_description_builder(layer_information, id_to_node):
  500. """build layer from description.
  501. """
  502. layer_type = layer_information[0]
  503. layer_input_ids = layer_information[1]
  504. if isinstance(layer_input_ids, Iterable):
  505. layer_input = list(map(lambda x: id_to_node[x], layer_input_ids))
  506. else:
  507. layer_input = id_to_node[layer_input_ids]
  508. layer_output = id_to_node[layer_information[2]]
  509. if layer_type.startswith("StubConv"):
  510. input_channel = layer_information[3]
  511. filters = layer_information[4]
  512. kernel_size = layer_information[5]
  513. stride = layer_information[6]
  514. return globals()[layer_type](
  515. input_channel, filters, kernel_size, stride, layer_input, layer_output
  516. )
  517. elif layer_type.startswith("StubDense"):
  518. input_units = layer_information[3]
  519. units = layer_information[4]
  520. return globals()[layer_type](input_units, units, layer_input, layer_output)
  521. elif layer_type.startswith("StubBatchNormalization"):
  522. num_features = layer_information[3]
  523. return globals()[layer_type](num_features, layer_input, layer_output)
  524. elif layer_type.startswith("StubDropout"):
  525. rate = layer_information[3]
  526. return globals()[layer_type](rate, layer_input, layer_output)
  527. elif layer_type.startswith("StubPooling"):
  528. kernel_size = layer_information[3]
  529. stride = layer_information[4]
  530. padding = layer_information[5]
  531. return globals()[layer_type](kernel_size, stride, padding, layer_input, layer_output)
  532. else:
  533. return globals()[layer_type](layer_input, layer_output)
  534. def layer_width(layer):
  535. """
  536. Get layer width.
  537. """
  538. if is_layer(layer, "Dense"):
  539. return layer.units
  540. if is_layer(layer, "Conv"):
  541. return layer.filters
  542. raise TypeError("The layer should be either Dense or Conv layer.")
  543. def set_torch_weight_to_stub(torch_layer, stub_layer):
  544. stub_layer.import_weights(torch_layer)
  545. def set_stub_weight_to_torch(stub_layer, torch_layer):
  546. stub_layer.export_weights(torch_layer)
  547. def get_conv_class(n_dim):
  548. conv_class_list = [StubConv1d, StubConv2d, StubConv3d]
  549. return conv_class_list[n_dim - 1]
  550. def get_dropout_class(n_dim):
  551. dropout_class_list = [StubDropout1d, StubDropout2d, StubDropout3d]
  552. return dropout_class_list[n_dim - 1]
  553. def get_global_avg_pooling_class(n_dim):
  554. global_avg_pooling_class_list = [
  555. StubGlobalPooling1d,
  556. StubGlobalPooling2d,
  557. StubGlobalPooling3d,
  558. ]
  559. return global_avg_pooling_class_list[n_dim - 1]
  560. def get_pooling_class(n_dim):
  561. pooling_class_list = [StubPooling1d, StubPooling2d, StubPooling3d]
  562. return pooling_class_list[n_dim - 1]
  563. def get_batch_norm_class(n_dim):
  564. batch_norm_class_list = [
  565. StubBatchNormalization1d,
  566. StubBatchNormalization2d,
  567. StubBatchNormalization3d,
  568. ]
  569. return batch_norm_class_list[n_dim - 1]
  570. def get_n_dim(layer):
  571. if isinstance(layer, (
  572. StubConv1d,
  573. StubDropout1d,
  574. StubGlobalPooling1d,
  575. StubPooling1d,
  576. StubBatchNormalization1d,
  577. )):
  578. return 1
  579. if isinstance(layer, (
  580. StubConv2d,
  581. StubDropout2d,
  582. StubGlobalPooling2d,
  583. StubPooling2d,
  584. StubBatchNormalization2d,
  585. )):
  586. return 2
  587. if isinstance(layer, (
  588. StubConv3d,
  589. StubDropout3d,
  590. StubGlobalPooling3d,
  591. StubPooling3d,
  592. StubBatchNormalization3d,
  593. )):
  594. return 3
  595. return -1

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能