You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

logger.py 7.3 kB


  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import contextlib
  10. import logging
  11. import os
  12. import sys
  13. _all_loggers = []
  14. _default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "INFO")
  15. _default_level = logging.getLevelName(_default_level_name.upper())
  16. def set_log_file(fout, mode="a"):
  17. r"""Sets log output file.
  18. Args:
  19. fout: file-like object that supports write and flush, or string for the filename
  20. mode: specify the mode to open log file if *fout* is a string
  21. """
  22. if isinstance(fout, str):
  23. fout = open(fout, mode)
  24. MegEngineLogFormatter.log_fout = fout
  25. class MegEngineLogFormatter(logging.Formatter):
  26. log_fout = None
  27. date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
  28. date = "%(asctime)s "
  29. msg = "%(message)s"
  30. max_lines = 256
  31. def _color_exc(self, msg):
  32. r"""Sets the color of message as the execution type."""
  33. return "\x1b[34m{}\x1b[0m".format(msg)
  34. def _color_dbg(self, msg):
  35. r"""Sets the color of message as the debugging type."""
  36. return "\x1b[36m{}\x1b[0m".format(msg)
  37. def _color_warn(self, msg):
  38. r"""Sets the color of message as the warning type."""
  39. return "\x1b[1;31m{}\x1b[0m".format(msg)
  40. def _color_err(self, msg):
  41. r"""Sets the color of message as the error type."""
  42. return "\x1b[1;4;31m{}\x1b[0m".format(msg)
  43. def _color_omitted(self, msg):
  44. r"""Sets the color of message as the omitted type."""
  45. return "\x1b[35m{}\x1b[0m".format(msg)
  46. def _color_normal(self, msg):
  47. r"""Sets the color of message as the normal type."""
  48. return msg
  49. def _color_date(self, msg):
  50. r"""Sets the color of message the same as date."""
  51. return "\x1b[32m{}\x1b[0m".format(msg)
  52. def format(self, record):
  53. if record.levelno == logging.DEBUG:
  54. mcl, mtxt = self._color_dbg, "DBG"
  55. elif record.levelno == logging.WARNING:
  56. mcl, mtxt = self._color_warn, "WRN"
  57. elif record.levelno == logging.ERROR:
  58. mcl, mtxt = self._color_err, "ERR"
  59. else:
  60. mcl, mtxt = self._color_normal, ""
  61. if mtxt:
  62. mtxt += " "
  63. if self.log_fout:
  64. self.__set_fmt(self.date_full + mtxt + self.msg)
  65. formatted = super(MegEngineLogFormatter, self).format(record)
  66. nr_line = formatted.count("\n") + 1
  67. if nr_line >= self.max_lines:
  68. head, body = formatted.split("\n", 1)
  69. formatted = "\n".join(
  70. [
  71. head,
  72. "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1),
  73. body,
  74. "}}END_LONG_LOG_{}_LINES".format(nr_line - 1),
  75. ]
  76. )
  77. self.log_fout.write(formatted)
  78. self.log_fout.write("\n")
  79. self.log_fout.flush()
  80. self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
  81. formatted = super(MegEngineLogFormatter, self).format(record)
  82. if record.exc_text or record.exc_info:
  83. # handle exception format
  84. b = formatted.find("Traceback ")
  85. if b != -1:
  86. s = formatted[b:]
  87. s = self._color_exc(" " + s.replace("\n", "\n "))
  88. formatted = formatted[:b] + s
  89. nr_line = formatted.count("\n") + 1
  90. if nr_line >= self.max_lines:
  91. lines = formatted.split("\n")
  92. remain = self.max_lines // 2
  93. removed = len(lines) - remain * 2
  94. if removed > 0:
  95. mid_msg = self._color_omitted(
  96. "[{} log lines omitted (would be written to output file "
  97. "if set_log_file() has been called;\n"
  98. " the threshold can be set at "
  99. "MegEngineLogFormatter.max_lines)]".format(removed)
  100. )
  101. formatted = "\n".join(lines[:remain] + [mid_msg] + lines[-remain:])
  102. return formatted
  103. if sys.version_info.major < 3:
  104. def __set_fmt(self, fmt):
  105. self._fmt = fmt
  106. else:
  107. def __set_fmt(self, fmt):
  108. self._style._fmt = fmt
  109. def get_logger(name=None, formatter=MegEngineLogFormatter):
  110. r"""Gets megengine logger with given name."""
  111. logger = logging.getLogger(name)
  112. if getattr(logger, "_init_done__", None):
  113. return logger
  114. logger._init_done__ = True
  115. logger.propagate = False
  116. logger.setLevel(_default_level)
  117. handler = logging.StreamHandler()
  118. handler.setFormatter(formatter(datefmt="%d %H:%M:%S"))
  119. handler.setLevel(0)
  120. del logger.handlers[:]
  121. logger.addHandler(handler)
  122. _all_loggers.append(logger)
  123. return logger
  124. def set_log_level(level, update_existing=True):
  125. r"""Sets default logging level.
  126. Args:
  127. level: loggin level given by python :mod:`logging` module
  128. update_existing: whether to update existing loggers
  129. """
  130. global _default_level # pylint: disable=global-statement
  131. _default_level = level
  132. if update_existing:
  133. for i in _all_loggers:
  134. i.setLevel(level)
  135. _logger = get_logger(__name__)
  136. try:
  137. if sys.version_info.major < 3:
  138. raise ImportError()
  139. from .core._imperative_rt.utils import Logger as _imperative_rt_logger
  140. class MegBrainLogFormatter(MegEngineLogFormatter):
  141. date = "%(asctime)s[mgb] "
  142. def _color_date(self, msg):
  143. return "\x1b[33m{}\x1b[0m".format(msg)
  144. _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
  145. _imperative_rt_logger.set_log_handler(_megbrain_logger)
  146. def set_mgb_log_level(level):
  147. r"""Sets megbrain log level
  148. Args:
  149. level: new log level
  150. Returns:
  151. original log level
  152. """
  153. _megbrain_logger.setLevel(level)
  154. if level == logging.getLevelName("ERROR"):
  155. rst = _imperative_rt_logger.set_log_level(
  156. _imperative_rt_logger.LogLevel.Error
  157. )
  158. elif level == logging.getLevelName("INFO"):
  159. rst = _imperative_rt_logger.set_log_level(
  160. _imperative_rt_logger.LogLevel.Info
  161. )
  162. else:
  163. rst = _imperative_rt_logger.set_log_level(
  164. _imperative_rt_logger.LogLevel.Debug
  165. )
  166. return rst
  167. set_mgb_log_level(_default_level)
  168. except ImportError as exc:
  169. def set_mgb_log_level(level):
  170. raise NotImplementedError("imperative_rt has not been imported")
  171. @contextlib.contextmanager
  172. def replace_mgb_log_level(level):
  173. r"""Replaces megbrain log level in a block and restore after exiting.
  174. Args:
  175. level: new log level
  176. """
  177. old = set_mgb_log_level(level)
  178. try:
  179. yield
  180. finally:
  181. set_mgb_log_level(old)
  182. def enable_debug_log():
  183. r"""Sets logging level to debug for all components."""
  184. set_log_level(logging.DEBUG)
  185. set_mgb_log_level(logging.DEBUG)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。