Browse Source

fix(mge/utils): fix module states input is dict or others

GitOrigin-RevId: f9701b6134
release-1.7
Megvii Engine Team 3 years ago
parent
commit
38b7cfdec1
1 changed files with 15 additions and 1 deletions
  1. +15
    -1
      imperative/python/megengine/utils/module_stats.py

+ 15
- 1
imperative/python/megengine/utils/module_stats.py View File

@@ -437,7 +437,21 @@ def module_stats(
has_inputs = True
if not isinstance(inputs, (tuple, list)):
inputs = [inputs]
inputs = [Tensor(input, dtype=np.float32) for input in inputs]

def load_tensor(x):
if isinstance(x, np.ndarray):
return Tensor(x)
elif isinstance(x, collections.abc.Mapping):
return {k: load_tensor(x) for k, v in x.items()}
elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple
return type(x)(*(load_tensor(value) for value in x))
elif isinstance(x, collections.abc.Sequence):
return [load_tensor(v) for v in x]
else:
return Tensor(x, dtype=np.float32)

inputs = load_tensor(inputs)

else:
if input_shapes:
if not isinstance(input_shapes[0], tuple):


Loading…
Cancel
Save