Browse Source

fix(mge/traced_module): fix treedef repr

GitOrigin-RevId: 3df05e9c22
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
bf0f3d31f5
2 changed files with 48 additions and 5 deletions
  1. +42
    -4
      imperative/python/megengine/traced_module/pytree.py
  2. +6
    -1
      imperative/python/megengine/traced_module/traced_module.py

+ 42
- 4
imperative/python/megengine/traced_module/pytree.py View File

@@ -10,7 +10,7 @@ import collections
from collections import OrderedDict, defaultdict
from functools import partial
from inspect import FullArgSpec
from typing import Any, Callable, List, NamedTuple, Tuple
from typing import Any, Callable, Dict, List, NamedTuple, Tuple

import numpy as np

@@ -284,8 +284,43 @@ class TreeDef:
and self.children_defs == other.children_defs
)

def _args_kwargs_repr(self):
if (
len(self.children_defs) == 2
and issubclass(self.children_defs[0].type, (List, Tuple))
and issubclass(self.children_defs[1].type, Dict)
):
args_def = self.children_defs[0]
content = ", ".join(repr(i) for i in args_def.children_defs)
kwargs_def = self.children_defs[1]
if kwargs_def.aux_data:
content += ", "
content += ", ".join(
str(i) + "=" + repr(j)
for i, j in zip(kwargs_def.aux_data, kwargs_def.children_defs)
)
return content
else:
return repr(self)

def __repr__(self):
return "{}[{}]".format(self.type.__name__, self.children_defs)
format_str = self.type.__name__ + "({})"
aux_data_delimiter = "="
if issubclass(self.type, List):
format_str = "[{}]"
if issubclass(self.type, Tuple):
format_str = "({})"
if issubclass(self.type, Dict):
format_str = "{{{}}}"
aux_data_delimiter = ":"
if self.aux_data:
content = ", ".join(
repr(i) + aux_data_delimiter + repr(j)
for i, j in zip(self.aux_data, self.children_defs)
)
else:
content = ", ".join(repr(i) for i in self.children_defs)
return format_str.format(content)


class LeafDef(TreeDef):
@@ -315,6 +350,9 @@ class LeafDef(TreeDef):
return hash(tuple([self.type, self.const_val]))

def __repr__(self):
return "Leaf({}[{}])".format(
", ".join(t.__name__ for t in self.type), self.const_val

return "{}".format(
self.const_val
if self.const_val is not None or type(None) in self.type
else self.type[0].__name__
)

+ 6
- 1
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1977,7 +1977,12 @@ class TracedModule(Module):
if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map
assert (
treedef in self.argdef_graph_map
), "support input args kwargs format: \n{}, but get: \n{}".format(
"\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()),
treedef._args_kwargs_repr(),
)
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.


Loading…
Cancel
Save