You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

utils.py 3.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from enum import Enum
  2. import json
class Constant:
    """Shared default hyper-parameter values, grouped by the stage that
    consumes them (data augmentation, architecture search, training).

    NOTE(review): the groupings follow the original section comments;
    the semantics of individual values are presumed from their names
    where not visible in this file.
    """

    # Data -- presumably cutout-style augmentation defaults; TODO confirm
    CUTOUT_HOLES = 1
    CUTOUT_RATIO = 0.5
    # Searcher -- size limits and kernel parameters for the model search
    MAX_MODEL_NUM = 1000
    MAX_LAYERS = 200
    N_NEIGHBOURS = 8
    MAX_MODEL_SIZE = (1 << 25)  # 2**25; upper bound on model size
    MAX_LAYER_WIDTH = 4096
    KERNEL_LAMBDA = 1.0
    BETA = 2.576  # NOTE(review): 2.576 is the 99% two-sided normal quantile -- verify intent
    T_MIN = 0.0001
    MLP_MODEL_LEN = 3
    MLP_MODEL_WIDTH = 5
    MODEL_LEN = 3
    MODEL_WIDTH = 64
    POOLING_KERNEL_SIZE = 2
    DENSE_DROPOUT_RATE = 0.5
    CONV_DROPOUT_RATE = 0.25
    MLP_DROPOUT_RATE = 0.25
    CONV_BLOCK_DISTANCE = 2
    # Trainer -- early-stopping defaults read by EarlyStop.__init__
    MAX_NO_IMPROVEMENT_NUM = 5
    MIN_LOSS_DEC = 1e-4
class OptimizeMode(Enum):
    """Direction in which a tuner treats the reward received from a Trial.

    'minimize': the tuner should minimize the reward.
    'maximize': the tuner should maximize the reward.
    """
    Minimize = 'minimize'
    Maximize = 'maximize'
  37. class EarlyStop:
  38. """A class check for early stop condition.
  39. Attributes:
  40. training_losses: Record all the training loss.
  41. minimum_loss: The minimum loss we achieve so far. Used to compared to determine no improvement condition.
  42. no_improvement_count: Current no improvement count.
  43. _max_no_improvement_num: The maximum number specified.
  44. _done: Whether condition met.
  45. _min_loss_dec: A threshold for loss improvement.
  46. """
  47. def __init__(self, max_no_improvement_num=None, min_loss_dec=None):
  48. self.training_losses = []
  49. self.minimum_loss = None
  50. self.no_improvement_count = 0
  51. self._max_no_improvement_num = max_no_improvement_num if max_no_improvement_num is not None \
  52. else Constant.MAX_NO_IMPROVEMENT_NUM
  53. self._done = False
  54. self._min_loss_dec = min_loss_dec if min_loss_dec is not None else Constant.MIN_LOSS_DEC
  55. def on_train_begin(self):
  56. """Initiate the early stop condition.
  57. Call on every time the training iteration begins.
  58. """
  59. self.training_losses = []
  60. self.no_improvement_count = 0
  61. self._done = False
  62. self.minimum_loss = float('inf')
  63. def on_epoch_end(self, loss):
  64. """Check the early stop condition.
  65. Call on every time the training iteration end.
  66. Args:
  67. loss: The loss function achieved by the epoch.
  68. Returns:
  69. True if condition met, otherwise False.
  70. """
  71. self.training_losses.append(loss)
  72. if self._done and loss > (self.minimum_loss - self._min_loss_dec):
  73. return False
  74. if loss > (self.minimum_loss - self._min_loss_dec):
  75. self.no_improvement_count += 1
  76. else:
  77. self.no_improvement_count = 0
  78. self.minimum_loss = loss
  79. if self.no_improvement_count > self._max_no_improvement_num:
  80. self._done = True
  81. return True
  82. def save_json_result(path, data):
  83. with open(path,'a') as f:
  84. json.dump(data,f)
  85. f.write('\n')

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能