You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

evolution_tuner.py 15 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
import json
import logging
import os
import pickle
import re
import subprocess
import sys
from collections import deque

import numpy as np
import torch

from evaluator import Evaluator
from network import ShuffleNetV2OneShot, PARSED_FLOPS
# Search-space entry types handled by the tuner (see _random_candidate).
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"

# Module-level logger.
_logger = logging.getLogger(__name__)
  15. class SPOSEvolution:
  16. """
  17. SPOS evolution tuner.
  18. Parameters
  19. ----------
  20. max_epochs : int
  21. Maximum number of epochs to run.
  22. num_select : int
  23. Number of survival candidates of each epoch.
  24. num_population : int
  25. Number of candidates at the start of each epoch. If candidates generated by
  26. crossover and mutation are not enough, the rest will be filled with random
  27. candidates.
  28. m_prob : float
  29. The probability of mutation.
  30. num_crossover : int
  31. Number of candidates generated by crossover in each epoch.
  32. num_mutation : int
  33. Number of candidates generated by mutation in each epoch.
  34. """
  35. def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
  36. num_crossover=25, num_mutation=25, epoch=0):
  37. assert num_population >= num_select
  38. self.max_epochs = max_epochs
  39. self.num_select = num_select
  40. self.num_population = num_population
  41. self.m_prob = m_prob
  42. self.num_crossover = num_crossover
  43. self.num_mutation = num_mutation
  44. self.epoch = epoch
  45. self.search_space = None
  46. self.random_state = np.random.RandomState(0)
  47. # self.evl = Evaluator()
  48. # async status
  49. self._to_evaluate_queue = deque()
  50. self._sending_parameter_queue = deque()
  51. self._pending_result_ids = set()
  52. self._reward_dict = dict()
  53. self._id2candidate = dict()
  54. self._st_callback = None
  55. self.cand_path = "./checkpoints"
  56. self.acc_path = "./acc"
  57. self.candidates = [] if epoch == 0 else self.load_candidates() # 第一轮初始尚未有生成的种群
  58. def load_candidates(self):
  59. # 从self.export_result()写入文件的候选模型,需要读入
  60. # {"LayerChoice1": [false, false, false, true], ... } -> {"LayerChoice1": {"_idx":3, "_value":"3"}, ... }
  61. print("## evolution -- load ## begin to load candidates in evolution...\n")
  62. file_dir, _, files = next(os.walk(self.cand_path))
  63. files = [i for i in files if "%03d_"%(self.epoch-1) in i]
  64. def get_true_index(l):
  65. return [i for i in range(len(l)) if l[i]][0]
  66. candidates = []
  67. for file in files:
  68. with open(os.path.join(file_dir, file), "r") as f:
  69. candidate = json.load(f)
  70. # 转换成合适的形式
  71. cand = {}
  72. for key, value in candidate.items():
  73. v = get_true_index(value)
  74. value = {"_value":str(v), "_idx":int(v)}
  75. cand.update({key:value})
  76. candidates.append(cand)
  77. print("## evolution -- load ## candidates loaded \n")
  78. return candidates
  79. def load_id2candidate(self):
  80. with open("./id2cand/%03d_id2candidate.json"%(self.epoch - 1), "r") as f:
  81. self.id2candidate = json.load(f)
  82. def update_search_space(self, search_space):
  83. """
  84. Handle the initialization/update event of search space.
  85. """
  86. print("## evolution -- update ## updating search space")
  87. self._search_space = search_space
  88. self._next_round()
  89. print("## evolution -- update ## search space updated")
  90. def _next_round(self):
  91. _logger.info("Epoch %d, generating...", self.epoch)
  92. if self.epoch == 0:
  93. self._get_random_population()
  94. self.export_results(self.candidates)
  95. self.evaluate_cands() # 评估全部的模型
  96. else:
  97. self.load_id2candidate()
  98. self.receive_trial_result()
  99. best_candidates = self._select_top_candidates()
  100. if self.epoch >= self.max_epochs:
  101. return
  102. self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
  103. self._get_random_population()
  104. self.export_results(self.candidates)
  105. self.evaluate_cands() # 评估全部的模型
  106. self.epoch += 1
  107. def _random_candidate(self):
  108. chosen_arch = dict()
  109. for key, val in self._search_space.items():
  110. if val["_type"] == LAYER_CHOICE:
  111. choices = val["_value"]
  112. index = self.random_state.randint(len(choices))
  113. chosen_arch[key] = {"_value": choices[index], "_idx": index}
  114. elif val["_type"] == INPUT_CHOICE:
  115. raise NotImplementedError("Input choice is not implemented yet.")
  116. return chosen_arch
  117. def _add_to_evaluate_queue(self, cand):
  118. _logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
  119. self._reward_dict[self._hashcode(cand)] = 0.
  120. self._to_evaluate_queue.append(cand)
  121. def _get_random_population(self):
  122. while len(self.candidates) < self.num_population:
  123. cand = self._random_candidate()
  124. if self._is_legal(cand):
  125. _logger.info("Random candidate generated.")
  126. self._add_to_evaluate_queue(cand)
  127. self.candidates.append(cand)
  128. def _get_crossover(self, best):
  129. result = []
  130. for _ in range(10 * self.num_crossover):
  131. cand_p1 = best[self.random_state.randint(len(best))]
  132. cand_p2 = best[self.random_state.randint(len(best))]
  133. assert cand_p1.keys() == cand_p2.keys()
  134. cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
  135. for k in cand_p1.keys()}
  136. if self._is_legal(cand):
  137. result.append(cand)
  138. self._add_to_evaluate_queue(cand)
  139. if len(result) >= self.num_crossover:
  140. break
  141. _logger.info("Found %d architectures with crossover.", len(result))
  142. return result
  143. def _get_mutation(self, best):
  144. result = []
  145. for _ in range(10 * self.num_mutation):
  146. cand = best[self.random_state.randint(len(best))].copy()
  147. mutation_sample = np.random.random_sample(len(cand))
  148. for s, k in zip(mutation_sample, cand):
  149. if s < self.m_prob:
  150. choices = self._search_space[k]["_value"]
  151. index = self.random_state.randint(len(choices))
  152. cand[k] = {"_value": choices[index], "_idx": index}
  153. if self._is_legal(cand):
  154. result.append(cand)
  155. self._add_to_evaluate_queue(cand)
  156. if len(result) >= self.num_mutation:
  157. break
  158. _logger.info("Found %d architectures with mutation.", len(result))
  159. return result
  160. def _get_architecture_repr(self, cand):
  161. return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
  162. self._hashcode(cand))
  163. def _is_legal(self, cand):
  164. if self._hashcode(cand) in self._reward_dict:
  165. return False
  166. return True
  167. # 将模型输出,并重训练、评估
  168. def evaluate_cands(self):
  169. """
  170. 1、对输出的模型进行重训练
  171. 2、对重训练后的模型进行评估
  172. 以上内容通过tester.py脚本完成
  173. """
  174. print("## evolution -- evaluate ## begin to evaluate candidates...")
  175. file_dir, _, files = next(os.walk(self.cand_path)) # 获取文件夹下的文件
  176. files = [i for i in files if "%03d_"%self.epoch in i]
  177. for file in files:
  178. file = os.path.join(file_dir, file)
  179. # self.evl.eval_model(epoch=self.epoch, architecture=file)
  180. python_interpreter_path = sys.executable
  181. subprocess.run([python_interpreter_path,\
  182. "evaluator.py", "--architecture", file, "--epoch", str(self.epoch)])
  183. print("## evolution -- evaluate ## candidates evaluated")
  184. def _select_top_candidates(self):
  185. print("## evolution -- select ## begin to select top candidates...")
  186. reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
  187. _logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
  188. result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
  189. _logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
  190. print("## evolution -- select ## selected done")
  191. return result
  192. @staticmethod
  193. def _hashcode(d):
  194. return json.dumps(d, sort_keys=True)
  195. def _bind_and_send_parameters(self):
  196. """
  197. There are two types of resources: parameter ids and candidates. This function is called at
  198. necessary times to bind these resources to send new trials with st_callback.
  199. """
  200. result = []
  201. while self._sending_parameter_queue and self._to_evaluate_queue:
  202. parameter_id = self._sending_parameter_queue.popleft()
  203. parameters = self._to_evaluate_queue.popleft()
  204. self._id2candidate[parameter_id] = parameters
  205. result.append(parameters)
  206. self._pending_result_ids.add(parameter_id)
  207. self._st_callback(parameter_id, parameters)
  208. _logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
  209. return result
  210. def generate_multiple_parameters(self, parameter_id_list, **kwargs):
  211. """
  212. Callback function necessary to implement a tuner. This will put more parameter ids into the
  213. parameter id queue.
  214. """
  215. if "st_callback" in kwargs and self._st_callback is None:
  216. self._st_callback = kwargs["st_callback"]
  217. for parameter_id in parameter_id_list:
  218. self._sending_parameter_queue.append(parameter_id)
  219. self._bind_and_send_parameters()
  220. return [] # always not use this. might induce problem of over-sending
  221. # def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
  222. # """
  223. # Callback function. Receive a trial result.
  224. # """
  225. # _logger.info("Candidate %d, reported reward %f", parameter_id, value)
  226. # self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
  227. def receive_trial_result(self):
  228. # 获取并更新self._reward_dict
  229. file_dir, _, files = next(os.walk(self.acc_path))
  230. files = [i for i in files if "%03d_"%(self.epoch-1) in i] # self.epoch-1: 读取上一轮的结果
  231. acc_dict = {}
  232. for file in files:
  233. with open(os.path.join(file_dir, file), "r") as f:
  234. acc_dict.update(json.load(f)) # {"000_001.json":0.56}
  235. for key, value in acc_dict.items():
  236. key = key.lstrip("./checkpoints/") # 删掉路径,仅保留文件名
  237. self._reward_dict.update({self.id2candidate[key]: value}) # todo {self.id2candidate[key]: key}
  238. def trial_end(self, parameter_id, success, **kwargs):
  239. """
  240. Callback function when a trial is ended and resource is released.
  241. """
  242. self._pending_result_ids.remove(parameter_id)
  243. if not self._pending_result_ids and not self._to_evaluate_queue:
  244. # a new epoch now
  245. self._next_round()
  246. assert self._st_callback is not None
  247. self._bind_and_send_parameters()
  248. def export_results(self, result):
  249. """
  250. Export a number of candidates to `checkpoints` dir.
  251. Parameters
  252. ----------
  253. result : dict
  254. Chosen architectures to be exported.
  255. """
  256. os.makedirs("checkpoints", exist_ok=True)
  257. os.makedirs("id2cand", exist_ok=True)
  258. self.id2candidate = {}
  259. for i, cand in enumerate(result):
  260. converted = dict()
  261. for cand_key, cand_val in cand.items():
  262. onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
  263. converted[cand_key] = onehot
  264. with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
  265. json.dump(converted, fp)
  266. """
  267. self.id2candidate:
  268. {
  269. 000_000.json: {"LayerChoice1": {"_values":3, "_idx":3}, "LayerChoice2": {"_values":2, "_idx":2}, ...}
  270. ......
  271. }
  272. """
  273. self.id2candidate.update({"%03d_%03d.json" % (self.epoch, i): json.dumps(result[i], sort_keys=True)})
  274. with open("./id2cand/%03d_id2candidate.json"%self.epoch, "w") as f:
  275. json.dump(self.id2candidate, f)
  276. class EvolutionWithFlops(SPOSEvolution):
  277. """
  278. This tuner extends the function of evolution tuner, by limiting the flops generated by tuner.
  279. Needs a function to examine the flops.
  280. """
  281. def __init__(self, flops_limit=330E6, **kwargs):
  282. super().__init__(**kwargs)
  283. # self.model = ShuffleNetV2OneShot()
  284. self.flops_limit = flops_limit
  285. with open(os.path.join(os.path.dirname(__file__), "./data/op_flops_dict.pkl"), "rb") as fp:
  286. self._op_flops_dict = pickle.load(fp)
  287. def _is_legal(self, cand):
  288. if not super()._is_legal(cand):
  289. return False
  290. if self.get_candidate_flops(cand) > self.flops_limit:
  291. return False
  292. return True
  293. def get_candidate_flops(self, candidate):
  294. """
  295. this method is the same with ShuffleNetV2OneShot.get_candidate_flops, but we dont need to initialize that class.
  296. """
  297. conv1_flops = self._op_flops_dict["conv1"][(3, 16,
  298. 224, 224, 2)]
  299. rest_flops = self._op_flops_dict["rest_operation"][(640, 1000,
  300. 7, 7, 1)]
  301. total_flops = conv1_flops + rest_flops
  302. for k, m in candidate.items():
  303. parsed_flops_dict = PARSED_FLOPS[k]
  304. if isinstance(m, dict): # to be compatible with classical nas format
  305. total_flops += parsed_flops_dict[m["_idx"]]
  306. else:
  307. total_flops += parsed_flops_dict[torch.max(m, 0)[1]]
  308. return total_flops

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能