You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

attribution_graph.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """
  2. Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. =============================================================
  13. """
  14. from typing import Type, Dict, Any
  15. import itertools
  16. import numpy as np
  17. from numpy import Inf
  18. import networkx
  19. import torch
  20. from torch.nn import Module
  21. from torch.utils.data import Dataset
  22. from .attribution_map import attribution_map, attr_map_similarity
  23. def graph_to_array(graph: networkx.Graph):
  24. weight_matrix = np.zeros((len(graph.nodes), len(graph.nodes)))
  25. for i, n1 in enumerate(graph.nodes):
  26. for j, n2 in enumerate(graph.nodes):
  27. try:
  28. dist = graph[n1][n2]["weight"]
  29. except KeyError:
  30. dist = 1
  31. weight_matrix[i, j] = dist
  32. return weight_matrix
  33. class FeatureMapExtractor():
  34. def __init__(self, module: Module):
  35. self.module = module
  36. self.feature_pool: Dict[str, Dict[str, Any]] = dict()
  37. self.register_hooks()
  38. def register_hooks(self):
  39. for name, m in self.module.named_modules():
  40. if "pool" in name:
  41. m.name = name
  42. self.feature_pool[name] = dict()
  43. def hook(m: Module, input, output):
  44. self.feature_pool[m.name]["feature"] = input
  45. self.feature_pool[name]["handle"] = m.register_forward_hook(hook)
  46. def _forward(self, x):
  47. self.module(x)
  48. def remove_hooks(self):
  49. for name, cfg in self.feature_pool.items():
  50. cfg["handle"].remove()
  51. cfg.clear()
  52. self.feature_pool.clear()
  53. def extract_final_map(self, x):
  54. self._forward(x)
  55. feature_map = None
  56. max_channel = 0
  57. min_size = Inf
  58. for name, cfg in self.feature_pool.items():
  59. f = cfg["feature"]
  60. if len(f) == 1 and isinstance(f[0], torch.Tensor):
  61. f = f[0]
  62. if f.dim() == 4: # BxCxHxW
  63. b, c, h, w = f.shape
  64. if c >= max_channel and 1 < h * w <= min_size:
  65. feature_map = f
  66. max_channel = c
  67. min_size = h * w
  68. return feature_map
  69. def get_attribution_graph(
  70. model: Module,
  71. attribution_type: Type,
  72. with_noise: bool,
  73. probe_data: Dataset,
  74. device: torch.device,
  75. norm_square: bool = False,
  76. ):
  77. attribution_graph = networkx.Graph()
  78. model = model.to(device)
  79. extractor = FeatureMapExtractor(model)
  80. for i, x in enumerate(probe_data):
  81. x = x.to(device)
  82. x.requires_grad_()
  83. attribution = attribution_map(
  84. func=lambda x: extractor.extract_final_map(x),
  85. attribution_type=attribution_type,
  86. with_noise=with_noise,
  87. probe_data=x.unsqueeze(0),
  88. norm_square=norm_square
  89. )
  90. attribution_graph.add_node(i, attribution_map=attribution)
  91. nodes = attribution_graph.nodes
  92. for i, j in itertools.product(nodes, nodes):
  93. if i < j:
  94. weight = attr_map_similarity(
  95. attribution_graph.nodes(data=True)[i]["attribution_map"],
  96. attribution_graph.nodes(data=True)[j]["attribution_map"]
  97. )
  98. attribution_graph.add_edge(i, j, weight=weight)
  99. return attribution_graph
  100. def edge_to_embedding(graph: networkx.Graph):
  101. adj = graph_to_array(graph)
  102. up_tri_mask = np.tri(*adj.shape[-2:], k=0, dtype=bool)
  103. return adj[up_tri_mask]
  104. def embedding_to_rank(embedding: np.ndarray):
  105. order = embedding.argsort()
  106. ranks = order.argsort()
  107. return ranks
  108. def graph_similarity(g1: networkx.Graph, g2: networkx.Graph, Lambda: float = 1.0):
  109. nodes_1 = g1.nodes(data=True)
  110. nodes_2 = g2.nodes(data=True)
  111. assert len(nodes_1) == len(nodes_2)
  112. # calculate vertex similarity
  113. v_s = 0
  114. n = len(g1.nodes)
  115. for i in range(n):
  116. v_s += attr_map_similarity(
  117. map_1=g1.nodes(data=True)[i]["attribution_map"],
  118. map_2=g2.nodes(data=True)[i]["attribution_map"]
  119. )
  120. vertex_similarity = v_s / n
  121. # calculate edges similarity
  122. emb_1 = edge_to_embedding(g1)
  123. emb_2 = edge_to_embedding(g2)
  124. rank_1 = embedding_to_rank(emb_1)
  125. rank_2 = embedding_to_rank(emb_2)
  126. k = emb_1.shape[0]
  127. edge_similarity = 1 - 6 * np.sum(np.square(rank_1 - rank_2)) / (k ** 3 - k)
  128. return vertex_similarity + Lambda * edge_similarity
  129. if __name__ == "__main__":
  130. from captum.attr import InputXGradient
  131. from torchvision.models import resnet34
  132. model_1 = resnet34(num_classes=10)
  133. graph_1 = get_attribution_graph(
  134. model_1,
  135. attribution_type=InputXGradient,
  136. with_noise=False,
  137. probe_data=torch.rand(10, 3, 244, 244),
  138. device=torch.device("cpu")
  139. )
  140. model_2 = resnet34(num_classes=10)
  141. graph_2 = get_attribution_graph(
  142. model_2,
  143. attribution_type=InputXGradient,
  144. with_noise=False,
  145. probe_data=torch.rand(10, 3, 244, 244),
  146. device=torch.device("cpu")
  147. )
  148. print(graph_similarity(graph_1, graph_2))

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)