From 38b7cfdec12e9e65fbd733cf426c433bf485a346 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Nov 2021 17:40:25 +0800 Subject: [PATCH] fix(mge/utils): fix module states input is dict or others GitOrigin-RevId: f9701b6134bf663345260e03f7f8a213a8fcb050 --- imperative/python/megengine/utils/module_stats.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index fe46c0ea..bc86fb42 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -437,7 +437,21 @@ def module_stats( has_inputs = True if not isinstance(inputs, (tuple, list)): inputs = [inputs] - inputs = [Tensor(input, dtype=np.float32) for input in inputs] + + def load_tensor(x): + if isinstance(x, np.ndarray): + return Tensor(x) + elif isinstance(x, collections.abc.Mapping): + return {k: load_tensor(x) for k, v in x.items()} + elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple + return type(x)(*(load_tensor(value) for value in x)) + elif isinstance(x, collections.abc.Sequence): + return [load_tensor(v) for v in x] + else: + return Tensor(x, dtype=np.float32) + + inputs = load_tensor(inputs) + else: if input_shapes: if not isinstance(input_shapes[0], tuple):