Browse Source

fix(mge/traced_module): fix treedef repr

GitOrigin-RevId: 3df05e9c22
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
bf0f3d31f5
2 changed files with 48 additions and 5 deletions
  1. +42
    -4
      imperative/python/megengine/traced_module/pytree.py
  2. +6
    -1
      imperative/python/megengine/traced_module/traced_module.py

+ 42
- 4
imperative/python/megengine/traced_module/pytree.py View File

@@ -10,7 +10,7 @@ import collections
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from functools import partial from functools import partial
from inspect import FullArgSpec from inspect import FullArgSpec
from typing import Any, Callable, List, NamedTuple, Tuple
from typing import Any, Callable, Dict, List, NamedTuple, Tuple


import numpy as np import numpy as np


@@ -284,8 +284,43 @@ class TreeDef:
and self.children_defs == other.children_defs and self.children_defs == other.children_defs
) )


def _args_kwargs_repr(self):
if (
len(self.children_defs) == 2
and issubclass(self.children_defs[0].type, (List, Tuple))
and issubclass(self.children_defs[1].type, Dict)
):
args_def = self.children_defs[0]
content = ", ".join(repr(i) for i in args_def.children_defs)
kwargs_def = self.children_defs[1]
if kwargs_def.aux_data:
content += ", "
content += ", ".join(
str(i) + "=" + repr(j)
for i, j in zip(kwargs_def.aux_data, kwargs_def.children_defs)
)
return content
else:
return repr(self)

def __repr__(self): def __repr__(self):
return "{}[{}]".format(self.type.__name__, self.children_defs)
format_str = self.type.__name__ + "({})"
aux_data_delimiter = "="
if issubclass(self.type, List):
format_str = "[{}]"
if issubclass(self.type, Tuple):
format_str = "({})"
if issubclass(self.type, Dict):
format_str = "{{{}}}"
aux_data_delimiter = ":"
if self.aux_data:
content = ", ".join(
repr(i) + aux_data_delimiter + repr(j)
for i, j in zip(self.aux_data, self.children_defs)
)
else:
content = ", ".join(repr(i) for i in self.children_defs)
return format_str.format(content)




class LeafDef(TreeDef): class LeafDef(TreeDef):
@@ -315,6 +350,9 @@ class LeafDef(TreeDef):
return hash(tuple([self.type, self.const_val])) return hash(tuple([self.type, self.const_val]))


def __repr__(self): def __repr__(self):
return "Leaf({}[{}])".format(
", ".join(t.__name__ for t in self.type), self.const_val

return "{}".format(
self.const_val
if self.const_val is not None or type(None) in self.type
else self.type[0].__name__
) )

+ 6
- 1
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1977,7 +1977,12 @@ class TracedModule(Module):
if hasattr(self, "argspec") and self.argspec is not None: if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs)) inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map
assert (
treedef in self.argdef_graph_map
), "support input args kwargs format: \n{}, but get: \n{}".format(
"\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()),
treedef._args_kwargs_repr(),
)
inputs = filter( inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace. ) # allow TracedModuleBuilder for retrace.


Loading…
Cancel
Save