From b4f9703f4833a5ffb2b4fb4bbbf29a2a3a5584d9 Mon Sep 17 00:00:00 2001 From: Qsingle <1271808136@qq.com> Date: Sun, 8 May 2022 21:05:17 +0800 Subject: [PATCH] fix(overflow): fix the overflow of the long_scalars in network_node and module_stats --- imperative/python/megengine/utils/module_stats.py | 2 +- imperative/python/megengine/utils/network_node.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index a0a3beea..29d311be 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -97,7 +97,7 @@ def flops_convNd(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 # N x Cout x H x W x (Cin x Kw x Kh + bias) return np.prod(outputs[0].shape) * ( - module.in_channels // module.groups * np.prod(module.kernel_size) + bias + float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias ) diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index 768792bb..7dfeffef 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -471,7 +471,7 @@ def flops_conv(opnode: ConvolutionForward, inputs, outputs): NCHW = np.prod(outputs[0].shape) bias = 1 if isinstance(opnode, ConvBiasForward) else 0 # N x Cout x H x W x (Cin x Kw x Kh) - return NCHW * (num_input * kw * kh + bias) + return NCHW * (float(num_input * kw * kh) + bias) @register_receptive_field(ConvolutionForward, ConvBiasForward)