You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.py 36 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919
  1. import json
  2. from collections.abc import Iterable
  3. from copy import deepcopy, copy
  4. from queue import Queue
  5. import numpy as np
  6. import torch
  7. from .layer_transformer import (
  8. add_noise,
  9. wider_bn,
  10. wider_next_conv,
  11. wider_next_dense,
  12. wider_pre_conv,
  13. wider_pre_dense,
  14. init_dense_weight,
  15. init_conv_weight,
  16. init_bn_weight,
  17. )
  18. from .layers import (
  19. StubAdd,
  20. StubConcatenate,
  21. StubReLU,
  22. get_batch_norm_class,
  23. get_conv_class,
  24. is_layer,
  25. layer_width,
  26. set_stub_weight_to_torch,
  27. set_torch_weight_to_stub,
  28. layer_description_extractor,
  29. layer_description_builder,
  30. )
  31. from utils import Constant
  32. class NetworkDescriptor:
  33. """A class describing the neural architecture for neural network kernel.
  34. It only record the width of convolutional and dense layers, and the skip-connection types and positions.
  35. """
  36. CONCAT_CONNECT = "concat"
  37. ADD_CONNECT = "add"
  38. def __init__(self):
  39. self.skip_connections = []
  40. self.layers = []
  41. @property
  42. def n_layers(self):
  43. return len(self.layers)
  44. def add_skip_connection(self, u, v, connection_type):
  45. """ Add a skip-connection to the descriptor.
  46. Args:
  47. u: Number of convolutional layers before the starting point.
  48. v: Number of convolutional layers before the ending point.
  49. connection_type: Must be either CONCAT_CONNECT or ADD_CONNECT.
  50. """
  51. if connection_type not in [self.CONCAT_CONNECT, self.ADD_CONNECT]:
  52. raise ValueError(
  53. "connection_type should be NetworkDescriptor.CONCAT_CONNECT "
  54. "or NetworkDescriptor.ADD_CONNECT."
  55. )
  56. self.skip_connections.append((u, v, connection_type))
  57. def to_json(self):
  58. ''' NetworkDescriptor to json representation
  59. '''
  60. skip_list = []
  61. for u, v, connection_type in self.skip_connections:
  62. skip_list.append({"from": u, "to": v, "type": connection_type})
  63. return {"node_list": self.layers, "skip_list": skip_list}
  64. def add_layer(self, layer):
  65. ''' add one layer
  66. '''
  67. self.layers.append(layer)
  68. class Node:
  69. """A class for intermediate output tensor (node) in the Graph.
  70. Attributes:
  71. shape: A tuple describing the shape of the tensor.
  72. """
  73. def __init__(self, shape):
  74. self.shape = shape
  75. class Graph:
  76. """A class representing the neural architecture graph of a model.
  77. Graph extracts the neural architecture graph from a model.
  78. Each node in the graph is a intermediate tensor between layers.
  79. Each layer is an edge in the graph.
  80. Notably, multiple edges may refer to the same layer.
  81. (e.g. Add layer is adding two tensor into one tensor. So it is related to two edges.)
  82. Attributes:
  83. weighted: A boolean of whether the weights and biases in the neural network
  84. should be included in the graph.
  85. input_shape: A tuple of integers, which does not include the batch axis.
  86. node_list: A list of integers. The indices of the list are the identifiers.
  87. layer_list: A list of stub layers. The indices of the list are the identifiers.
  88. node_to_id: A dict instance mapping from node integers to their identifiers.
  89. layer_to_id: A dict instance mapping from stub layers to their identifiers.
  90. layer_id_to_input_node_ids: A dict instance mapping from layer identifiers
  91. to their input nodes identifiers.
  92. layer_id_to_output_node_ids: A dict instance mapping from layer identifiers
  93. to their output nodes identifiers.
  94. adj_list: A two dimensional list. The adjacency list of the graph. The first dimension is
  95. identified by tensor identifiers. In each edge list, the elements are two-element tuples
  96. of (tensor identifier, layer identifier).
  97. reverse_adj_list: A reverse adjacent list in the same format as adj_list.
  98. operation_history: A list saving all the network morphism operations.
  99. vis: A dictionary of temporary storage for whether an local operation has been done
  100. during the network morphism.
  101. """
  102. def __init__(self, input_shape, weighted=True):
  103. """Initializer for Graph.
  104. """
  105. self.input_shape = input_shape
  106. self.weighted = weighted
  107. self.node_list = []
  108. self.layer_list = []
  109. # node id start with 0
  110. self.node_to_id = {}
  111. self.layer_to_id = {}
  112. self.layer_id_to_input_node_ids = {}
  113. self.layer_id_to_output_node_ids = {}
  114. self.adj_list = {}
  115. self.reverse_adj_list = {}
  116. self.operation_history = []
  117. self.n_dim = len(input_shape) - 1
  118. self.conv = get_conv_class(self.n_dim)
  119. self.batch_norm = get_batch_norm_class(self.n_dim)
  120. self.vis = None
  121. self._add_node(Node(input_shape))
  122. def add_layer(self, layer, input_node_id):
  123. """Add a layer to the Graph.
  124. Args:
  125. layer: An instance of the subclasses of StubLayer in layers.py.
  126. input_node_id: An integer. The ID of the input node of the layer.
  127. Returns:
  128. output_node_id: An integer. The ID of the output node of the layer.
  129. """
  130. if isinstance(input_node_id, Iterable):
  131. layer.input = list(map(lambda x: self.node_list[x], input_node_id))
  132. output_node_id = self._add_node(Node(layer.output_shape))
  133. for node_id in input_node_id:
  134. self._add_edge(layer, node_id, output_node_id)
  135. else:
  136. layer.input = self.node_list[input_node_id]
  137. output_node_id = self._add_node(Node(layer.output_shape))
  138. self._add_edge(layer, input_node_id, output_node_id)
  139. layer.output = self.node_list[output_node_id]
  140. return output_node_id
  141. def clear_operation_history(self):
  142. self.operation_history = []
  143. @property
  144. def n_nodes(self):
  145. """Return the number of nodes in the model."""
  146. return len(self.node_list)
  147. @property
  148. def n_layers(self):
  149. """Return the number of layers in the model."""
  150. return len(self.layer_list)
  151. def _add_node(self, node):
  152. """Add a new node to node_list and give the node an ID.
  153. Args:
  154. node: An instance of Node.
  155. Returns:
  156. node_id: An integer.
  157. """
  158. node_id = len(self.node_list)
  159. self.node_to_id[node] = node_id
  160. self.node_list.append(node)
  161. self.adj_list[node_id] = []
  162. self.reverse_adj_list[node_id] = []
  163. return node_id
  164. def _add_edge(self, layer, input_id, output_id):
  165. """Add a new layer to the graph. The nodes should be created in advance."""
  166. if layer in self.layer_to_id:
  167. layer_id = self.layer_to_id[layer]
  168. if input_id not in self.layer_id_to_input_node_ids[layer_id]:
  169. self.layer_id_to_input_node_ids[layer_id].append(input_id)
  170. if output_id not in self.layer_id_to_output_node_ids[layer_id]:
  171. self.layer_id_to_output_node_ids[layer_id].append(output_id)
  172. else:
  173. layer_id = len(self.layer_list)
  174. self.layer_list.append(layer)
  175. self.layer_to_id[layer] = layer_id
  176. self.layer_id_to_input_node_ids[layer_id] = [input_id]
  177. self.layer_id_to_output_node_ids[layer_id] = [output_id]
  178. self.adj_list[input_id].append((output_id, layer_id))
  179. self.reverse_adj_list[output_id].append((input_id, layer_id))
  180. def _redirect_edge(self, u_id, v_id, new_v_id):
  181. """Redirect the layer to a new node.
  182. Change the edge originally from `u_id` to `v_id` into an edge from `u_id` to `new_v_id`
  183. while keeping all other property of the edge the same.
  184. """
  185. layer_id = None
  186. for index, edge_tuple in enumerate(self.adj_list[u_id]):
  187. if edge_tuple[0] == v_id:
  188. layer_id = edge_tuple[1]
  189. self.adj_list[u_id][index] = (new_v_id, layer_id)
  190. self.layer_list[layer_id].output = self.node_list[new_v_id]
  191. break
  192. for index, edge_tuple in enumerate(self.reverse_adj_list[v_id]):
  193. if edge_tuple[0] == u_id:
  194. layer_id = edge_tuple[1]
  195. self.reverse_adj_list[v_id].remove(edge_tuple)
  196. break
  197. self.reverse_adj_list[new_v_id].append((u_id, layer_id))
  198. for index, value in enumerate(
  199. self.layer_id_to_output_node_ids[layer_id]):
  200. if value == v_id:
  201. self.layer_id_to_output_node_ids[layer_id][index] = new_v_id
  202. break
  203. def _replace_layer(self, layer_id, new_layer):
  204. """Replace the layer with a new layer."""
  205. old_layer = self.layer_list[layer_id]
  206. new_layer.input = old_layer.input
  207. new_layer.output = old_layer.output
  208. new_layer.output.shape = new_layer.output_shape
  209. self.layer_list[layer_id] = new_layer
  210. self.layer_to_id[new_layer] = layer_id
  211. self.layer_to_id.pop(old_layer)
  212. @property
  213. def topological_order(self):
  214. """Return the topological order of the node IDs from the input node to the output node."""
  215. q = Queue()
  216. in_degree = {}
  217. for i in range(self.n_nodes):
  218. in_degree[i] = 0
  219. for u in range(self.n_nodes):
  220. for v, _ in self.adj_list[u]:
  221. in_degree[v] += 1
  222. for i in range(self.n_nodes):
  223. if in_degree[i] == 0:
  224. q.put(i)
  225. order_list = []
  226. while not q.empty():
  227. u = q.get()
  228. order_list.append(u)
  229. for v, _ in self.adj_list[u]:
  230. in_degree[v] -= 1
  231. if in_degree[v] == 0:
  232. q.put(v)
  233. return order_list
  234. def _get_pooling_layers(self, start_node_id, end_node_id):
  235. """
  236. Given two node IDs, return all the pooling layers between them.
  237. Conv layer with strid > 1 is also considered as a Pooling layer.
  238. """
  239. layer_list = []
  240. node_list = [start_node_id]
  241. assert self._depth_first_search(end_node_id, layer_list, node_list)
  242. ret = []
  243. for layer_id in layer_list:
  244. layer = self.layer_list[layer_id]
  245. if is_layer(layer, "Pooling"):
  246. ret.append(layer)
  247. elif is_layer(layer, "Conv") and layer.stride != 1:
  248. ret.append(layer)
  249. return ret
  250. def _depth_first_search(self, target_id, layer_id_list, node_list):
  251. """Search for all the layers and nodes down the path.
  252. A recursive function to search all the layers and nodes between the node in the node_list
  253. and the node with target_id."""
  254. assert len(node_list) <= self.n_nodes
  255. u = node_list[-1]
  256. if u == target_id:
  257. return True
  258. for v, layer_id in self.adj_list[u]:
  259. layer_id_list.append(layer_id)
  260. node_list.append(v)
  261. if self._depth_first_search(target_id, layer_id_list, node_list):
  262. return True
  263. layer_id_list.pop()
  264. node_list.pop()
  265. return False
  266. def _search(self, u, start_dim, total_dim, n_add):
  267. """Search the graph for all the layers to be widened caused by an operation.
  268. It is an recursive function with duplication check to avoid deadlock.
  269. It searches from a starting node u until the corresponding layers has been widened.
  270. Args:
  271. u: The starting node ID.
  272. start_dim: The position to insert the additional dimensions.
  273. total_dim: The total number of dimensions the layer has before widening.
  274. n_add: The number of dimensions to add.
  275. """
  276. if (u, start_dim, total_dim, n_add) in self.vis:
  277. return
  278. self.vis[(u, start_dim, total_dim, n_add)] = True
  279. for v, layer_id in self.adj_list[u]:
  280. layer = self.layer_list[layer_id]
  281. if is_layer(layer, "Conv"):
  282. new_layer = wider_next_conv(
  283. layer, start_dim, total_dim, n_add, self.weighted
  284. )
  285. self._replace_layer(layer_id, new_layer)
  286. elif is_layer(layer, "Dense"):
  287. new_layer = wider_next_dense(
  288. layer, start_dim, total_dim, n_add, self.weighted
  289. )
  290. self._replace_layer(layer_id, new_layer)
  291. elif is_layer(layer, "BatchNormalization"):
  292. new_layer = wider_bn(
  293. layer, start_dim, total_dim, n_add, self.weighted)
  294. self._replace_layer(layer_id, new_layer)
  295. self._search(v, start_dim, total_dim, n_add)
  296. elif is_layer(layer, "Concatenate"):
  297. if self.layer_id_to_input_node_ids[layer_id][1] == u:
  298. # u is on the right of the concat
  299. # next_start_dim += next_total_dim - total_dim
  300. left_dim = self._upper_layer_width(
  301. self.layer_id_to_input_node_ids[layer_id][0]
  302. )
  303. next_start_dim = start_dim + left_dim
  304. next_total_dim = total_dim + left_dim
  305. else:
  306. next_start_dim = start_dim
  307. next_total_dim = total_dim + self._upper_layer_width(
  308. self.layer_id_to_input_node_ids[layer_id][1]
  309. )
  310. self._search(v, next_start_dim, next_total_dim, n_add)
  311. else:
  312. self._search(v, start_dim, total_dim, n_add)
  313. for v, layer_id in self.reverse_adj_list[u]:
  314. layer = self.layer_list[layer_id]
  315. if is_layer(layer, "Conv"):
  316. new_layer = wider_pre_conv(layer, n_add, self.weighted)
  317. self._replace_layer(layer_id, new_layer)
  318. elif is_layer(layer, "Dense"):
  319. new_layer = wider_pre_dense(layer, n_add, self.weighted)
  320. self._replace_layer(layer_id, new_layer)
  321. elif is_layer(layer, "Concatenate"):
  322. continue
  323. else:
  324. self._search(v, start_dim, total_dim, n_add)
  325. def _upper_layer_width(self, u):
  326. for v, layer_id in self.reverse_adj_list[u]:
  327. layer = self.layer_list[layer_id]
  328. if is_layer(layer, "Conv") or is_layer(layer, "Dense"):
  329. return layer_width(layer)
  330. elif is_layer(layer, "Concatenate"):
  331. a = self.layer_id_to_input_node_ids[layer_id][0]
  332. b = self.layer_id_to_input_node_ids[layer_id][1]
  333. return self._upper_layer_width(a) + self._upper_layer_width(b)
  334. else:
  335. return self._upper_layer_width(v)
  336. return self.node_list[0].shape[-1]
  337. def to_deeper_model(self, target_id, new_layer):
  338. """Insert a relu-conv-bn block after the target block.
  339. Args:
  340. target_id: A convolutional layer ID. The new block should be inserted after the block.
  341. new_layer: An instance of StubLayer subclasses.
  342. """
  343. self.operation_history.append(
  344. ("to_deeper_model", target_id, new_layer))
  345. input_id = self.layer_id_to_input_node_ids[target_id][0]
  346. output_id = self.layer_id_to_output_node_ids[target_id][0]
  347. if self.weighted:
  348. if is_layer(new_layer, "Dense"):
  349. init_dense_weight(new_layer)
  350. elif is_layer(new_layer, "Conv"):
  351. init_conv_weight(new_layer)
  352. elif is_layer(new_layer, "BatchNormalization"):
  353. init_bn_weight(new_layer)
  354. self._insert_new_layers([new_layer], input_id, output_id)
  355. def to_wider_model(self, pre_layer_id, n_add):
  356. """Widen the last dimension of the output of the pre_layer.
  357. Args:
  358. pre_layer_id: The ID of a convolutional layer or dense layer.
  359. n_add: The number of dimensions to add.
  360. """
  361. self.operation_history.append(("to_wider_model", pre_layer_id, n_add))
  362. pre_layer = self.layer_list[pre_layer_id]
  363. output_id = self.layer_id_to_output_node_ids[pre_layer_id][0]
  364. dim = layer_width(pre_layer)
  365. self.vis = {}
  366. self._search(output_id, dim, dim, n_add)
  367. # Update the tensor shapes.
  368. for u in self.topological_order:
  369. for v, layer_id in self.adj_list[u]:
  370. self.node_list[v].shape = self.layer_list[layer_id].output_shape
  371. def _insert_new_layers(self, new_layers, start_node_id, end_node_id):
  372. """Insert the new_layers after the node with start_node_id."""
  373. new_node_id = self._add_node(deepcopy(self.node_list[end_node_id]))
  374. temp_output_id = new_node_id
  375. for layer in new_layers[:-1]:
  376. temp_output_id = self.add_layer(layer, temp_output_id)
  377. self._add_edge(new_layers[-1], temp_output_id, end_node_id)
  378. new_layers[-1].input = self.node_list[temp_output_id]
  379. new_layers[-1].output = self.node_list[end_node_id]
  380. self._redirect_edge(start_node_id, end_node_id, new_node_id)
  381. def _block_end_node(self, layer_id, block_size):
  382. ret = self.layer_id_to_output_node_ids[layer_id][0]
  383. for _ in range(block_size - 2):
  384. ret = self.adj_list[ret][0][0]
  385. return ret
  386. def _dense_block_end_node(self, layer_id):
  387. return self.layer_id_to_input_node_ids[layer_id][0]
  388. def _conv_block_end_node(self, layer_id):
  389. """Get the input node ID of the last layer in the block by layer ID.
  390. Return the input node ID of the last layer in the convolutional block.
  391. Args:
  392. layer_id: the convolutional layer ID.
  393. """
  394. return self._block_end_node(layer_id, Constant.CONV_BLOCK_DISTANCE)
  395. def to_add_skip_model(self, start_id, end_id):
  396. """Add a weighted add skip-connection from after start node to end node.
  397. Args:
  398. start_id: The convolutional layer ID, after which to start the skip-connection.
  399. end_id: The convolutional layer ID, after which to end the skip-connection.
  400. """
  401. self.operation_history.append(("to_add_skip_model", start_id, end_id))
  402. filters_end = self.layer_list[end_id].output.shape[-1]
  403. filters_start = self.layer_list[start_id].output.shape[-1]
  404. start_node_id = self.layer_id_to_output_node_ids[start_id][0]
  405. pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0]
  406. end_node_id = self.layer_id_to_output_node_ids[end_id][0]
  407. skip_output_id = self._insert_pooling_layer_chain(
  408. start_node_id, end_node_id)
  409. # Add the conv layer in order to align the number of channels with end layer id
  410. new_conv_layer = get_conv_class(
  411. self.n_dim)(
  412. filters_start,
  413. filters_end,
  414. 1)
  415. skip_output_id = self.add_layer(new_conv_layer, skip_output_id)
  416. # Add the add layer.
  417. add_input_node_id = self._add_node(
  418. deepcopy(self.node_list[end_node_id]))
  419. add_layer = StubAdd()
  420. self._redirect_edge(pre_end_node_id, end_node_id, add_input_node_id)
  421. self._add_edge(add_layer, add_input_node_id, end_node_id)
  422. self._add_edge(add_layer, skip_output_id, end_node_id)
  423. add_layer.input = [
  424. self.node_list[add_input_node_id],
  425. self.node_list[skip_output_id],
  426. ]
  427. add_layer.output = self.node_list[end_node_id]
  428. self.node_list[end_node_id].shape = add_layer.output_shape
  429. # Set weights to the additional conv layer.
  430. if self.weighted:
  431. filter_shape = (1,) * self.n_dim
  432. weights = np.zeros((filters_end, filters_start) + filter_shape)
  433. bias = np.zeros(filters_end)
  434. new_conv_layer.set_weights(
  435. (add_noise(weights, np.array([0, 1])), add_noise(
  436. bias, np.array([0, 1])))
  437. )
  438. def to_concat_skip_model(self, start_id, end_id):
  439. """Add a weighted add concatenate connection from after start node to end node.
  440. Args:
  441. start_id: The convolutional layer ID, after which to start the skip-connection.
  442. end_id: The convolutional layer ID, after which to end the skip-connection.
  443. """
  444. self.operation_history.append(
  445. ("to_concat_skip_model", start_id, end_id))
  446. filters_end = self.layer_list[end_id].output.shape[-1]
  447. filters_start = self.layer_list[start_id].output.shape[-1]
  448. start_node_id = self.layer_id_to_output_node_ids[start_id][0]
  449. pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0]
  450. end_node_id = self.layer_id_to_output_node_ids[end_id][0]
  451. skip_output_id = self._insert_pooling_layer_chain(
  452. start_node_id, end_node_id)
  453. concat_input_node_id = self._add_node(
  454. deepcopy(self.node_list[end_node_id]))
  455. self._redirect_edge(pre_end_node_id, end_node_id, concat_input_node_id)
  456. concat_layer = StubConcatenate()
  457. concat_layer.input = [
  458. self.node_list[concat_input_node_id],
  459. self.node_list[skip_output_id],
  460. ]
  461. concat_output_node_id = self._add_node(Node(concat_layer.output_shape))
  462. self._add_edge(
  463. concat_layer,
  464. concat_input_node_id,
  465. concat_output_node_id)
  466. self._add_edge(concat_layer, skip_output_id, concat_output_node_id)
  467. concat_layer.output = self.node_list[concat_output_node_id]
  468. self.node_list[concat_output_node_id].shape = concat_layer.output_shape
  469. # Add the concatenate layer.
  470. # concat过channel数增加,用conv class 回到原先的channel数
  471. new_conv_layer = get_conv_class(self.n_dim)(
  472. filters_start + filters_end, filters_end, 1
  473. )
  474. self._add_edge(new_conv_layer, concat_output_node_id, end_node_id)
  475. new_conv_layer.input = self.node_list[concat_output_node_id]
  476. new_conv_layer.output = self.node_list[end_node_id]
  477. self.node_list[end_node_id].shape = new_conv_layer.output_shape
  478. if self.weighted:
  479. filter_shape = (1,) * self.n_dim
  480. weights = np.zeros((filters_end, filters_end) + filter_shape)
  481. for i in range(filters_end):
  482. filter_weight = np.zeros((filters_end,) + filter_shape)
  483. center_index = (i,) + (0,) * self.n_dim
  484. filter_weight[center_index] = 1
  485. weights[i, ...] = filter_weight
  486. weights = np.concatenate(
  487. (weights, np.zeros((filters_end, filters_start) + filter_shape)), axis=1
  488. )
  489. bias = np.zeros(filters_end)
  490. new_conv_layer.set_weights(
  491. (add_noise(weights, np.array([0, 1])), add_noise(
  492. bias, np.array([0, 1])))
  493. )
  494. def _insert_pooling_layer_chain(self, start_node_id, end_node_id):
  495. """
  496. insert pooling layer
  497. """
  498. skip_output_id = start_node_id
  499. # 得到从start_node_id 到 end_node_id之间的所有pooling layer(包括conv layer stride > 1)
  500. for layer in self._get_pooling_layers(start_node_id, end_node_id):
  501. new_layer = deepcopy(layer)
  502. # 如果是conv层需要重新初始化weights
  503. if is_layer(new_layer, "Conv"):
  504. # start node id 的通道数
  505. filters = self.node_list[start_node_id].shape[-1]
  506. new_layer = get_conv_class(self.n_dim)(
  507. filters, filters, 1, layer.stride)
  508. if self.weighted:
  509. init_conv_weight(new_layer)
  510. else:
  511. new_layer = deepcopy(layer)
  512. skip_output_id = self.add_layer(new_layer, skip_output_id)
  513. skip_output_id = self.add_layer(StubReLU(), skip_output_id)
  514. return skip_output_id
  515. def extract_descriptor(self):
  516. """Extract the the description of the Graph as an instance of NetworkDescriptor."""
  517. main_chain = self.get_main_chain()
  518. index_in_main_chain = {}
  519. for index, u in enumerate(main_chain):
  520. index_in_main_chain[u] = index
  521. ret = NetworkDescriptor()
  522. for u in main_chain:
  523. for v, layer_id in self.adj_list[u]:
  524. if v not in index_in_main_chain:
  525. continue
  526. layer = self.layer_list[layer_id]
  527. copied_layer = copy(layer)
  528. copied_layer.weights = None
  529. ret.add_layer(deepcopy(copied_layer))
  530. for u in index_in_main_chain:
  531. for v, layer_id in self.adj_list[u]:
  532. if v not in index_in_main_chain:
  533. temp_u = u
  534. temp_v = v
  535. temp_layer_id = layer_id
  536. skip_type = None
  537. while not (
  538. temp_v in index_in_main_chain and temp_u in index_in_main_chain):
  539. if is_layer(
  540. self.layer_list[temp_layer_id], "Concatenate"):
  541. skip_type = NetworkDescriptor.CONCAT_CONNECT
  542. if is_layer(self.layer_list[temp_layer_id], "Add"):
  543. skip_type = NetworkDescriptor.ADD_CONNECT
  544. temp_u = temp_v
  545. temp_v, temp_layer_id = self.adj_list[temp_v][0]
  546. ret.add_skip_connection(
  547. index_in_main_chain[u], index_in_main_chain[temp_u], skip_type
  548. )
  549. elif index_in_main_chain[v] - index_in_main_chain[u] != 1:
  550. skip_type = None
  551. if is_layer(self.layer_list[layer_id], "Concatenate"):
  552. skip_type = NetworkDescriptor.CONCAT_CONNECT
  553. if is_layer(self.layer_list[layer_id], "Add"):
  554. skip_type = NetworkDescriptor.ADD_CONNECT
  555. ret.add_skip_connection(
  556. index_in_main_chain[u], index_in_main_chain[v], skip_type
  557. )
  558. return ret
  559. def clear_weights(self):
  560. ''' clear weights of the graph
  561. '''
  562. self.weighted = False
  563. for layer in self.layer_list:
  564. layer.weights = None
  565. def produce_torch_model(self):
  566. """Build a new Torch model based on the current graph."""
  567. return TorchModel(self)
  568. def produce_json_model(self):
  569. """Build a new Json model based on the current graph."""
  570. return JSONModel(self).data
  571. @classmethod
  572. def parsing_json_model(cls, json_model):
  573. '''build a graph from json
  574. '''
  575. return json_to_graph(json_model)
  576. def _layer_ids_in_order(self, layer_ids):
  577. node_id_to_order_index = {}
  578. for index, node_id in enumerate(self.topological_order):
  579. node_id_to_order_index[node_id] = index
  580. return sorted(
  581. layer_ids,
  582. key=lambda layer_id: node_id_to_order_index[
  583. self.layer_id_to_output_node_ids[layer_id][0]
  584. ],
  585. )
  586. def _layer_ids_by_type(self, type_str):
  587. return list(
  588. filter(
  589. lambda layer_id: is_layer(self.layer_list[layer_id], type_str),
  590. range(self.n_layers),
  591. )
  592. )
  593. def get_main_chain_layers(self):
  594. """Return a list of layer IDs in the main chain."""
  595. main_chain = self.get_main_chain()
  596. ret = []
  597. for u in main_chain:
  598. for v, layer_id in self.adj_list[u]:
  599. if v in main_chain and u in main_chain:
  600. ret.append(layer_id)
  601. return ret
  602. def _conv_layer_ids_in_order(self):
  603. return list(
  604. filter(
  605. lambda layer_id: is_layer(self.layer_list[layer_id], "Conv"),
  606. self.get_main_chain_layers(),
  607. )
  608. )
  609. def _dense_layer_ids_in_order(self):
  610. return self._layer_ids_in_order(self._layer_ids_by_type("Dense"))
  611. def deep_layer_ids(self):
  612. ret = []
  613. for layer_id in self.get_main_chain_layers():
  614. layer = self.layer_list[layer_id]
  615. # GAP之后就不插入layer了
  616. if is_layer(layer, "GlobalAveragePooling"):
  617. break
  618. if is_layer(layer, "Add") or is_layer(layer, "Concatenate"):
  619. continue
  620. ret.append(layer_id)
  621. return ret
  622. def wide_layer_ids(self):
  623. return (
  624. self._conv_layer_ids_in_order(
  625. )[:-1] + self._dense_layer_ids_in_order()[:-1]
  626. )
  627. def skip_connection_layer_ids(self):
  628. return self.deep_layer_ids()[:-1]
  629. def size(self):
  630. return sum(list(map(lambda x: x.size(), self.layer_list)))
  631. def get_main_chain(self):
  632. """Returns the main chain node ID list."""
  633. pre_node = {}
  634. distance = {}
  635. # 初始化每个节点距离为0,他的前一个节点为自己
  636. for i in range(self.n_nodes):
  637. distance[i] = 0
  638. pre_node[i] = i
  639. # 遍历所有节点,根据邻接表找到他的前一个节点以及他本身的位置
  640. for i in range(self.n_nodes - 1):
  641. for u in range(self.n_nodes):
  642. for v, _ in self.adj_list[u]:
  643. if distance[u] + 1 > distance[v]:
  644. distance[v] = distance[u] + 1
  645. pre_node[v] = u
  646. # temp_id记录距离最大的node
  647. temp_id = 0
  648. for i in range(self.n_nodes):
  649. if distance[i] > distance[temp_id]:
  650. temp_id = i
  651. # 从距离最大的node开始不断找到他的前一个节点,最终找到主链
  652. ret = []
  653. for i in range(self.n_nodes + 5):
  654. ret.append(temp_id)
  655. if pre_node[temp_id] == temp_id:
  656. break
  657. temp_id = pre_node[temp_id]
  658. assert temp_id == pre_node[temp_id]
  659. ret.reverse()
  660. return ret
  661. class TorchModel(torch.nn.Module):
  662. """A neural network class using pytorch constructed from an instance of Graph."""
  663. def __init__(self, graph):
  664. super(TorchModel, self).__init__()
  665. self.graph = graph
  666. self.layers = torch.nn.ModuleList()
  667. for layer in graph.layer_list:
  668. self.layers.append(layer.to_real_layer())
  669. if graph.weighted:
  670. for index, layer in enumerate(self.layers):
  671. set_stub_weight_to_torch(self.graph.layer_list[index], layer)
  672. for index, layer in enumerate(self.layers):
  673. self.add_module(str(index), layer)
  674. def forward(self, input_tensor):
  675. topo_node_list = self.graph.topological_order
  676. output_id = topo_node_list[-1]
  677. input_id = topo_node_list[0]
  678. node_list = deepcopy(self.graph.node_list)
  679. node_list[input_id] = input_tensor
  680. for v in topo_node_list:
  681. for u, layer_id in self.graph.reverse_adj_list[v]:
  682. layer = self.graph.layer_list[layer_id]
  683. torch_layer = self.layers[layer_id]
  684. if isinstance(layer, (StubAdd, StubConcatenate)):
  685. edge_input_tensor = list(
  686. map(
  687. lambda x: node_list[x],
  688. self.graph.layer_id_to_input_node_ids[layer_id],
  689. )
  690. )
  691. else:
  692. edge_input_tensor = node_list[u]
  693. temp_tensor = torch_layer(edge_input_tensor)
  694. node_list[v] = temp_tensor
  695. return node_list[output_id]
  696. def set_weight_to_graph(self):
  697. self.graph.weighted = True
  698. for index, layer in enumerate(self.layers):
  699. set_torch_weight_to_stub(layer, self.graph.layer_list[index])
  700. class JSONModel:
  701. def __init__(self, graph):
  702. data = dict()
  703. node_list = list()
  704. layer_list = list()
  705. operation_history = list()
  706. data["input_shape"] = graph.input_shape
  707. vis = graph.vis
  708. data["vis"] = list(vis.keys()) if vis is not None else None
  709. data["weighted"] = graph.weighted
  710. for item in graph.operation_history:
  711. if item[0] == "to_deeper_model":
  712. operation_history.append(
  713. [
  714. item[0],
  715. item[1],
  716. layer_description_extractor(item[2], graph.node_to_id),
  717. ]
  718. )
  719. else:
  720. operation_history.append(item)
  721. data["operation_history"] = operation_history
  722. data["layer_id_to_input_node_ids"] = graph.layer_id_to_input_node_ids
  723. data["layer_id_to_output_node_ids"] = graph.layer_id_to_output_node_ids
  724. data["adj_list"] = graph.adj_list
  725. data["reverse_adj_list"] = graph.reverse_adj_list
  726. for node in graph.node_list:
  727. node_id = graph.node_to_id[node]
  728. node_information = node.shape
  729. node_list.append((node_id, node_information))
  730. for layer_id, item in enumerate(graph.layer_list):
  731. layer = graph.layer_list[layer_id]
  732. layer_information = layer_description_extractor(
  733. layer, graph.node_to_id)
  734. layer_list.append((layer_id, layer_information))
  735. data["node_list"] = node_list
  736. data["layer_list"] = layer_list
  737. self.data = data
  738. def graph_to_json(graph, json_model_path):
  739. json_out = graph.produce_json_model()
  740. with open(json_model_path, "w") as fout:
  741. json.dump(json_out, fout)
  742. json_out = json.dumps(json_out)
  743. return json_out
  744. def json_to_graph(json_model: str):
  745. json_model = json.loads(json_model)
  746. # restore graph data from json data
  747. input_shape = tuple(json_model["input_shape"])
  748. node_list = list()
  749. node_to_id = dict()
  750. id_to_node = dict()
  751. layer_list = list()
  752. layer_to_id = dict()
  753. operation_history = list()
  754. graph = Graph(input_shape, False)
  755. graph.input_shape = input_shape
  756. vis = json_model["vis"]
  757. graph.vis = {
  758. tuple(item): True for item in vis} if vis is not None else None
  759. graph.weighted = json_model["weighted"]
  760. layer_id_to_input_node_ids = json_model["layer_id_to_input_node_ids"]
  761. graph.layer_id_to_input_node_ids = {
  762. int(k): v for k, v in layer_id_to_input_node_ids.items()
  763. }
  764. layer_id_to_output_node_ids = json_model["layer_id_to_output_node_ids"]
  765. graph.layer_id_to_output_node_ids = {
  766. int(k): v for k, v in layer_id_to_output_node_ids.items()
  767. }
  768. adj_list = {}
  769. for k, v in json_model["adj_list"].items():
  770. adj_list[int(k)] = [tuple(i) for i in v]
  771. graph.adj_list = adj_list
  772. reverse_adj_list = {}
  773. for k, v in json_model["reverse_adj_list"].items():
  774. reverse_adj_list[int(k)] = [tuple(i) for i in v]
  775. graph.reverse_adj_list = reverse_adj_list
  776. for item in json_model["node_list"]:
  777. new_node = Node(tuple(item[1]))
  778. node_id = item[0]
  779. node_list.append(new_node)
  780. node_to_id[new_node] = node_id
  781. id_to_node[node_id] = new_node
  782. for item in json_model["operation_history"]:
  783. if item[0] == "to_deeper_model":
  784. operation_history.append(
  785. (item[0], item[1], layer_description_builder(item[2], id_to_node))
  786. )
  787. else:
  788. operation_history.append(item)
  789. graph.operation_history = operation_history
  790. for item in json_model["layer_list"]:
  791. new_layer = layer_description_builder(item[1], id_to_node)
  792. layer_id = int(item[0])
  793. layer_list.append(new_layer)
  794. layer_to_id[new_layer] = layer_id
  795. graph.node_list = node_list
  796. graph.node_to_id = node_to_id
  797. graph.layer_list = layer_list
  798. graph.layer_to_id = layer_to_id
  799. return graph

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。