You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pretty_printers.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import gdb
  2. import gdb.printing
  3. import gdb.types
  4. def eval_on_val(val, eval_str):
  5. eval_str = "(*({}*)({})).{}".format(val.type, val.address, eval_str)
  6. return gdb.parse_and_eval(eval_str)
  7. class SmallVectorPrinter:
  8. def __init__(self, val):
  9. t = val.type.template_argument(0)
  10. self.begin = val['m_begin_ptr'].cast(t.pointer())
  11. self.end = val['m_end_ptr'].cast(t.pointer())
  12. self.size = self.end - self.begin
  13. self.capacity = val['m_capacity_ptr'].cast(t.pointer()) - val['m_begin_ptr'].cast(t.pointer())
  14. def to_string(self):
  15. return 'SmallVector of Size {}'.format(self.size)
  16. def display_hint(self):
  17. return 'array'
  18. def children(self):
  19. for i in range(self.size):
  20. yield "[{}]".format(i), (self.begin+i).dereference()
  21. class MaybePrinter:
  22. def __init__(self, val):
  23. self.val = val['m_ptr']
  24. def to_string(self):
  25. if self.val:
  26. return 'Some {}'.format(self.val)
  27. else:
  28. return 'None'
  29. def display_hint(self):
  30. return 'array'
  31. def children(self):
  32. if self.val:
  33. yield '[0]', self.val.dereference()
  34. class ToStringPrinter:
  35. def __init__(self, val):
  36. self.val = val
  37. def to_string(self):
  38. return eval_on_val(self.val, "to_string().c_str()").string()
  39. class ReprPrinter:
  40. def __init__(self, val):
  41. self.val = val
  42. def to_string(self):
  43. return eval_on_val(self.val, "repr().c_str()").string()
  44. class HandlePrinter:
  45. def __init__(self, val):
  46. impl = gdb.lookup_type("mgb::imperative::interpreter::intl::TensorInfo")
  47. self.val = val.cast(impl.pointer())
  48. def to_string(self):
  49. if self.val:
  50. return 'Handle of TensorInfo at {}'.format(self.val)
  51. else:
  52. return 'Empty Handle'
  53. def display_hint(self):
  54. return 'array'
  55. def children(self):
  56. if self.val:
  57. yield '[0]', self.val.dereference()
  58. def print_small_tensor(device_nd):
  59. size = device_nd["m_storage"]["m_size"]
  60. ndim = device_nd["m_layout"]["ndim"]
  61. dim0 = device_nd["m_layout"]["shape"][0]
  62. stride0 = device_nd["m_layout"]["stride"][0]
  63. dtype = device_nd["m_layout"]["dtype"]
  64. if size == 0:
  65. return "<empty>"
  66. if ndim > 1:
  67. return "<ndim > 1>"
  68. if dim0 > 64:
  69. return "<size tool large>"
  70. raw_ptr = device_nd["m_storage"]["m_data"]["_M_ptr"]
  71. dtype_name = dtype["m_trait"]["name"].string()
  72. dtype_map = {
  73. "Float32": (gdb.lookup_type("float"), float),
  74. "Int32": (gdb.lookup_type("int"), int),
  75. }
  76. if dtype_name not in dtype_map:
  77. return "<dtype unsupported>"
  78. else:
  79. ctype, pytype = dtype_map[dtype_name]
  80. ptr = raw_ptr.cast(ctype.pointer())
  81. array = []
  82. for i in range(dim0):
  83. array.append((pytype)((ptr + i * int(stride0)).dereference()))
  84. return str(array)
  85. class LogicalTensorDescPrinter:
  86. def __init__(self, val):
  87. self.layout = val['layout']
  88. self.comp_node = val['comp_node']
  89. self.value = val['value']
  90. def to_string(self):
  91. return 'LogicalTensorDesc'
  92. def children(self):
  93. yield 'layout', self.layout
  94. yield 'comp_node', self.comp_node
  95. yield 'value', print_small_tensor(self.value)
  96. class OpDefPrinter:
  97. def __init__(self, val):
  98. self.val = val
  99. def to_string(self):
  100. return self.val.dynamic_type.name
  101. def children(self):
  102. concrete_val = self.val.address.cast(self.val.dynamic_type.pointer()).dereference()
  103. for field in concrete_val.type.fields():
  104. if field.is_base_class or field.artificial:
  105. continue
  106. if field.name == 'sm_typeinfo':
  107. continue
  108. yield field.name, concrete_val[field.name]
  109. pp = gdb.printing.RegexpCollectionPrettyPrinter("MegEngine")
  110. # megdnn
  111. pp.add_printer('megdnn::SmallVectorImpl', '^megdnn::SmallVector(Impl)?<.*>$', SmallVectorPrinter)
  112. pp.add_printer('megdnn::TensorLayout', '^megdnn::TensorLayout$', ToStringPrinter)
  113. pp.add_printer('megdnn::TensorShape', '^megdnn::TensorShape$', ToStringPrinter)
  114. # megbrain
  115. pp.add_printer('mgb::CompNode', '^mgb::CompNode$', ToStringPrinter)
  116. pp.add_printer('mgb::Maybe', '^mgb::Maybe<.*>$', MaybePrinter)
  117. # imperative
  118. pp.add_printer('mgb::imperative::LogicalTensorDesc', '^mgb::imperative::LogicalTensorDesc$', LogicalTensorDescPrinter)
  119. pp.add_printer('mgb::imperative::OpDef', '^mgb::imperative::OpDef$', OpDefPrinter)
  120. pp.add_printer('mgb::imperative::Subgraph', '^mgb::imperative::Subgraph$', ReprPrinter)
  121. pp.add_printer('mgb::imperative::EncodedSubgraph', '^mgb::imperative::EncodedSubgraph$', ReprPrinter)
  122. gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)
  123. def override_pretty_printer_for(val):
  124. type = val.type.strip_typedefs()
  125. if type.code == gdb.TYPE_CODE_PTR:
  126. if not val:
  127. return None
  128. target_typename = str(type.target().strip_typedefs())
  129. if target_typename == "mgb::imperative::OpDef":
  130. return OpDefPrinter(val.dereference())
  131. if target_typename == "mgb::imperative::interpreter::Interpreter::HandleImpl":
  132. return HandlePrinter(val)
  133. gdb.pretty_printers.append(override_pretty_printer_for)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。