diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py index 86c02c92..b89eb1d1 100644 --- a/python_module/megengine/optimizer/optimizer.py +++ b/python_module/megengine/optimizer/optimizer.py @@ -225,13 +225,13 @@ class Optimizer(metaclass=ABCMeta): param.grad.reset_zero() def bcast_param(self): + key = 0 for group in self.param_groups: for param in group["params"]: bcast_param( - param, - "bcast_param_" + str(get_group_id()), - is_root=(get_rank() == 0), + param, "bcast_param_" + str(key), is_root=(get_rank() == 0), ) + key += 1 def state_dict(self) -> Dict: r"""Export the optimizer state.