You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

exc.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. """exception handling"""
  10. from . import mgb as _mgb
  11. class MegBrainError(Exception):
  12. """exception class used by megbrain library"""
  13. tracker = None
  14. """the tracker setup by :func:`.set_exc_opr_tracker` when the related
  15. operator is created"""
  16. tracker_grad_orig = None
  17. """if this operator is created by taking gradient, this var would be the
  18. tracker of the operator that causes the grad."""
  19. def __init__(self, msg, tracker, tracker_grad_orig):
  20. assert isinstance(msg, str)
  21. super().__init__(msg, tracker, tracker_grad_orig)
  22. self.tracker = tracker
  23. self.tracker_grad_orig = tracker_grad_orig
  24. @classmethod
  25. def _format_tracker(cls, tracker):
  26. return ("| " + i for i in str(tracker).split("\n"))
  27. def __str__(self):
  28. lines = []
  29. lines.extend(self.args[0].split("\n"))
  30. if self.tracker is not None:
  31. lines.append("Exception tracker:")
  32. lines.extend(self._format_tracker(self.tracker))
  33. if self.tracker_grad_orig is not None:
  34. lines.append(
  35. "Exception caused by taking grad of another operator with tracker:"
  36. )
  37. lines.extend(self._format_tracker(self.tracker_grad_orig))
  38. while not lines[-1].strip():
  39. lines.pop()
  40. for idx, ct in enumerate(lines):
  41. if ct.startswith("bt:"):
  42. lines[idx] = "+ " + lines[idx]
  43. for t in range(idx + 1, len(lines)):
  44. lines[t] = "| " + lines[t]
  45. break
  46. return "\n".join(lines)
# Register MegBrainError with the native ``mgb`` module at import time.
# NOTE(review): presumably this lets the native library raise MegBrainError
# (constructed as MegBrainError(msg, tracker, tracker_grad_orig)) for its
# errors — TODO confirm against the _reg_exception_class implementation.
_mgb._reg_exception_class(MegBrainError)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。