You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pretty_printers.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import gdb
  2. import gdb.printing
  3. import gdb.types
  4. def dynamic_cast(val):
  5. assert val.type.code == gdb.TYPE_CODE_REF
  6. val = val.cast(val.dynamic_type)
  7. return val
  8. def eval_on_val(val, eval_str):
  9. if val.type.code == gdb.TYPE_CODE_REF:
  10. val = val.referenced_value()
  11. address = val.address
  12. eval_str = "(*({}){}){}".format(address.type, int(address), eval_str)
  13. return gdb.parse_and_eval(eval_str)
  14. class SmallVectorPrinter:
  15. def __init__(self, val):
  16. t = val.type.template_argument(0)
  17. self.begin = val['m_begin_ptr'].cast(t.pointer())
  18. self.end = val['m_end_ptr'].cast(t.pointer())
  19. self.size = self.end - self.begin
  20. self.capacity = val['m_capacity_ptr'].cast(t.pointer()) - val['m_begin_ptr'].cast(t.pointer())
  21. def to_string(self):
  22. return 'SmallVector of Size {}'.format(self.size)
  23. def display_hint(self):
  24. return 'array'
  25. def children(self):
  26. for i in range(self.size):
  27. yield "[{}]".format(i), (self.begin+i).dereference()
  28. class MaybePrinter:
  29. def __init__(self, val):
  30. self.val = val['m_ptr']
  31. def to_string(self):
  32. if self.val:
  33. return 'Some {}'.format(self.val)
  34. else:
  35. return 'None'
  36. def display_hint(self):
  37. return 'array'
  38. def children(self):
  39. if self.val:
  40. yield '[0]', self.val.dereference()
  41. class ToStringPrinter:
  42. def __init__(self, val):
  43. self.val = val
  44. def to_string(self):
  45. return eval_on_val(self.val, ".to_string().c_str()").string()
  46. class ReprPrinter:
  47. def __init__(self, val):
  48. self.val = val
  49. def to_string(self):
  50. val = self.val
  51. if val.type.code == gdb.TYPE_CODE_REF:
  52. val = val.referenced_value()
  53. return eval_on_val(val, ".repr().c_str()").string()
  54. class HandlePrinter:
  55. def __init__(self, val):
  56. impl = gdb.lookup_type("mgb::imperative::interpreter::intl::TensorInfo")
  57. self.val = val.cast(impl.pointer())
  58. def to_string(self):
  59. if self.val:
  60. return 'Handle of TensorInfo at {}'.format(self.val)
  61. else:
  62. return 'Empty Handle'
  63. def display_hint(self):
  64. return 'array'
  65. def children(self):
  66. if self.val:
  67. yield '[0]', self.val.dereference()
  68. def print_small_tensor(device_nd):
  69. size = device_nd["m_storage"]["m_size"]
  70. ndim = device_nd["m_layout"]["ndim"]
  71. dim0 = device_nd["m_layout"]["shape"][0]
  72. stride0 = device_nd["m_layout"]["stride"][0]
  73. dtype = device_nd["m_layout"]["dtype"]
  74. if size == 0:
  75. return "<empty>"
  76. if ndim > 1:
  77. return "<ndim > 1>"
  78. if dim0 > 64:
  79. return "<size tool large>"
  80. raw_ptr = device_nd["m_storage"]["m_data"]["_M_ptr"]
  81. dtype_name = dtype["m_trait"]["name"].string()
  82. dtype_map = {
  83. "Float32": (gdb.lookup_type("float"), float),
  84. "Int32": (gdb.lookup_type("int"), int),
  85. }
  86. if dtype_name not in dtype_map:
  87. return "<dtype unsupported>"
  88. else:
  89. ctype, pytype = dtype_map[dtype_name]
  90. ptr = raw_ptr.cast(ctype.pointer())
  91. array = []
  92. for i in range(dim0):
  93. array.append((pytype)((ptr + i * int(stride0)).dereference()))
  94. return str(array)
  95. class LogicalTensorDescPrinter:
  96. def __init__(self, val):
  97. self.layout = val['layout']
  98. self.comp_node = val['comp_node']
  99. self.value = val['value']
  100. def to_string(self):
  101. return 'LogicalTensorDesc'
  102. def children(self):
  103. yield 'layout', self.layout
  104. yield 'comp_node', self.comp_node
  105. yield 'value', print_small_tensor(self.value)
  106. class OpDefPrinter:
  107. def __init__(self, val):
  108. self.val = val
  109. def to_string(self):
  110. return self.val.dynamic_type.name
  111. def children(self):
  112. concrete_val = self.val.address.cast(self.val.dynamic_type.pointer()).dereference()
  113. for field in concrete_val.type.fields():
  114. if field.is_base_class or field.artificial:
  115. continue
  116. if field.name == 'sm_typeinfo':
  117. continue
  118. yield field.name, concrete_val[field.name]
  119. class SpanPrinter:
  120. def __init__(self, val):
  121. self.begin = val['m_begin']
  122. self.end = val['m_end']
  123. self.size = self.end - self.begin
  124. def to_string(self):
  125. return 'Span of Size {}'.format(self.size)
  126. def display_hint(self):
  127. return 'array'
  128. def children(self):
  129. for i in range(self.size):
  130. yield "[{}]".format(i), (self.begin+i).dereference()
  131. pp = gdb.printing.RegexpCollectionPrettyPrinter("MegEngine")
  132. # megdnn
  133. pp.add_printer('megdnn::SmallVectorImpl', '^megdnn::SmallVector(Impl)?<.*>$', SmallVectorPrinter)
  134. pp.add_printer('megdnn::TensorLayout', '^megdnn::TensorLayout$', ToStringPrinter)
  135. pp.add_printer('megdnn::TensorShape', '^megdnn::TensorShape$', ToStringPrinter)
  136. # megbrain
  137. pp.add_printer('mgb::CompNode', '^mgb::CompNode$', ToStringPrinter)
  138. pp.add_printer('mgb::Maybe', '^mgb::Maybe<.*>$', MaybePrinter)
  139. # imperative
  140. pp.add_printer('mgb::imperative::LogicalTensorDesc', '^mgb::imperative::LogicalTensorDesc$', LogicalTensorDescPrinter)
  141. pp.add_printer('mgb::imperative::OpDef', '^mgb::imperative::OpDef$', OpDefPrinter)
  142. pp.add_printer('mgb::imperative::Subgraph', '^mgb::imperative::Subgraph$', ReprPrinter)
  143. pp.add_printer('mgb::imperative::EncodedSubgraph', '^mgb::imperative::EncodedSubgraph$', ReprPrinter)
  144. # imperative dispatch
  145. pp.add_printer('mgb::imperative::ValueRef', '^mgb::imperative::ValueRef$', ToStringPrinter)
  146. pp.add_printer('mgb::imperative::Span', '^mgb::imperative::Span<.*>$', SpanPrinter)
  147. gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)
  148. def override_pretty_printer_for(val):
  149. type = val.type.strip_typedefs()
  150. if type.code == gdb.TYPE_CODE_PTR:
  151. if not val:
  152. return None
  153. target_typename = str(type.target().strip_typedefs())
  154. if target_typename == "mgb::imperative::OpDef":
  155. return OpDefPrinter(val.dereference())
  156. if target_typename == "mgb::imperative::interpreter::Interpreter::HandleImpl":
  157. return HandlePrinter(val)
  158. gdb.pretty_printers.append(override_pretty_printer_for)