You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

logger.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # -*- coding: utf-8 -*-
  2. import contextlib
  3. import logging
  4. import os
  5. import sys
  6. _all_loggers = []
  7. _default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "INFO")
  8. _default_level = logging.getLevelName(_default_level_name.upper())
  9. def set_log_file(fout, mode="a"):
  10. r"""Sets log output file.
  11. Args:
  12. fout: file-like object that supports write and flush, or string for the filename
  13. mode: specify the mode to open log file if *fout* is a string
  14. """
  15. if isinstance(fout, str):
  16. fout = open(fout, mode)
  17. MegEngineLogFormatter.log_fout = fout
  18. class MegEngineLogFormatter(logging.Formatter):
  19. log_fout = None
  20. date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
  21. date = "%(asctime)s "
  22. msg = "%(message)s"
  23. max_lines = 256
  24. def _color_exc(self, msg):
  25. r"""Sets the color of message as the execution type."""
  26. return "\x1b[34m{}\x1b[0m".format(msg)
  27. def _color_dbg(self, msg):
  28. r"""Sets the color of message as the debugging type."""
  29. return "\x1b[36m{}\x1b[0m".format(msg)
  30. def _color_warn(self, msg):
  31. r"""Sets the color of message as the warning type."""
  32. return "\x1b[1;31m{}\x1b[0m".format(msg)
  33. def _color_err(self, msg):
  34. r"""Sets the color of message as the error type."""
  35. return "\x1b[1;4;31m{}\x1b[0m".format(msg)
  36. def _color_omitted(self, msg):
  37. r"""Sets the color of message as the omitted type."""
  38. return "\x1b[35m{}\x1b[0m".format(msg)
  39. def _color_normal(self, msg):
  40. r"""Sets the color of message as the normal type."""
  41. return msg
  42. def _color_date(self, msg):
  43. r"""Sets the color of message the same as date."""
  44. return "\x1b[32m{}\x1b[0m".format(msg)
  45. def format(self, record):
  46. if record.levelno == logging.DEBUG:
  47. mcl, mtxt = self._color_dbg, "DBG"
  48. elif record.levelno == logging.WARNING:
  49. mcl, mtxt = self._color_warn, "WRN"
  50. elif record.levelno == logging.ERROR:
  51. mcl, mtxt = self._color_err, "ERR"
  52. else:
  53. mcl, mtxt = self._color_normal, ""
  54. if mtxt:
  55. mtxt += " "
  56. if self.log_fout:
  57. self.__set_fmt(self.date_full + mtxt + self.msg)
  58. formatted = super(MegEngineLogFormatter, self).format(record)
  59. nr_line = formatted.count("\n") + 1
  60. if nr_line >= self.max_lines:
  61. head, body = formatted.split("\n", 1)
  62. formatted = "\n".join(
  63. [
  64. head,
  65. "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1),
  66. body,
  67. "}}END_LONG_LOG_{}_LINES".format(nr_line - 1),
  68. ]
  69. )
  70. self.log_fout.write(formatted)
  71. self.log_fout.write("\n")
  72. self.log_fout.flush()
  73. self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
  74. formatted = super(MegEngineLogFormatter, self).format(record)
  75. if record.exc_text or record.exc_info:
  76. # handle exception format
  77. b = formatted.find("Traceback ")
  78. if b != -1:
  79. s = formatted[b:]
  80. s = self._color_exc(" " + s.replace("\n", "\n "))
  81. formatted = formatted[:b] + s
  82. nr_line = formatted.count("\n") + 1
  83. if nr_line >= self.max_lines:
  84. lines = formatted.split("\n")
  85. remain = self.max_lines // 2
  86. removed = len(lines) - remain * 2
  87. if removed > 0:
  88. mid_msg = self._color_omitted(
  89. "[{} log lines omitted (would be written to output file "
  90. "if set_log_file() has been called;\n"
  91. " the threshold can be set at "
  92. "MegEngineLogFormatter.max_lines)]".format(removed)
  93. )
  94. formatted = "\n".join(lines[:remain] + [mid_msg] + lines[-remain:])
  95. return formatted
  96. if sys.version_info.major < 3:
  97. def __set_fmt(self, fmt):
  98. self._fmt = fmt
  99. else:
  100. def __set_fmt(self, fmt):
  101. self._style._fmt = fmt
  102. def get_logger(name=None, formatter=MegEngineLogFormatter):
  103. r"""Gets megengine logger with given name."""
  104. logger = logging.getLogger(name)
  105. if getattr(logger, "_init_done__", None):
  106. return logger
  107. logger._init_done__ = True
  108. logger.propagate = False
  109. logger.setLevel(_default_level)
  110. handler = logging.StreamHandler()
  111. handler.setFormatter(formatter(datefmt="%d %H:%M:%S"))
  112. handler.setLevel(0)
  113. del logger.handlers[:]
  114. logger.addHandler(handler)
  115. _all_loggers.append(logger)
  116. return logger
  117. def set_log_level(level, update_existing=True):
  118. r"""Sets default logging level.
  119. Args:
  120. level: loggin level given by python :mod:`logging` module
  121. update_existing: whether to update existing loggers
  122. """
  123. global _default_level # pylint: disable=global-statement
  124. origin_level = _default_level
  125. _default_level = level
  126. if update_existing:
  127. for i in _all_loggers:
  128. i.setLevel(level)
  129. return origin_level
  130. _logger = get_logger(__name__)
  131. try:
  132. if sys.version_info.major < 3:
  133. raise ImportError()
  134. from .core._imperative_rt.utils import Logger as _imperative_rt_logger
  135. class MegBrainLogFormatter(MegEngineLogFormatter):
  136. date = "%(asctime)s[mgb] "
  137. def _color_date(self, msg):
  138. return "\x1b[33m{}\x1b[0m".format(msg)
  139. _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
  140. _imperative_rt_logger.set_log_handler(_megbrain_logger)
  141. def set_mgb_log_level(level):
  142. r"""Sets megbrain log level
  143. Args:
  144. level: new log level
  145. Returns:
  146. original log level
  147. """
  148. _megbrain_logger.setLevel(level)
  149. if level == logging.getLevelName("ERROR"):
  150. rst = _imperative_rt_logger.set_log_level(
  151. _imperative_rt_logger.LogLevel.Error
  152. )
  153. elif level == logging.getLevelName("INFO"):
  154. rst = _imperative_rt_logger.set_log_level(
  155. _imperative_rt_logger.LogLevel.Info
  156. )
  157. else:
  158. rst = _imperative_rt_logger.set_log_level(
  159. _imperative_rt_logger.LogLevel.Debug
  160. )
  161. return rst
  162. set_mgb_log_level(_default_level)
  163. except ImportError as exc:
  164. def set_mgb_log_level(level):
  165. raise NotImplementedError("imperative_rt has not been imported")
  166. @contextlib.contextmanager
  167. def replace_mgb_log_level(level):
  168. r"""Replaces megbrain log level in a block and restore after exiting.
  169. Args:
  170. level: new log level
  171. """
  172. old = set_mgb_log_level(level)
  173. try:
  174. yield
  175. finally:
  176. set_mgb_log_level(old)
  177. def enable_debug_log():
  178. r"""Sets logging level to debug for all components."""
  179. set_log_level(logging.DEBUG)
  180. set_mgb_log_level(logging.DEBUG)