|
|
@@ -437,7 +437,21 @@ def module_stats( |
|
|
|
has_inputs = True |
|
|
|
if not isinstance(inputs, (tuple, list)): |
|
|
|
inputs = [inputs] |
|
|
|
inputs = [Tensor(input, dtype=np.float32) for input in inputs] |
|
|
|
|
|
|
|
def load_tensor(x): |
|
|
|
if isinstance(x, np.ndarray): |
|
|
|
return Tensor(x) |
|
|
|
elif isinstance(x, collections.abc.Mapping): |
|
|
|
return {k: load_tensor(x) for k, v in x.items()} |
|
|
|
elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple |
|
|
|
return type(x)(*(load_tensor(value) for value in x)) |
|
|
|
elif isinstance(x, collections.abc.Sequence): |
|
|
|
return [load_tensor(v) for v in x] |
|
|
|
else: |
|
|
|
return Tensor(x, dtype=np.float32) |
|
|
|
|
|
|
|
inputs = load_tensor(inputs) |
|
|
|
|
|
|
|
else: |
|
|
|
if input_shapes: |
|
|
|
if not isinstance(input_shapes[0], tuple): |
|
|
|