|
|
@@ -225,13 +225,13 @@ class Optimizer(metaclass=ABCMeta): |
|
|
|
param.grad.reset_zero() |
|
|
|
|
|
|
|
def bcast_param(self): |
|
|
|
key = 0 |
|
|
|
for group in self.param_groups: |
|
|
|
for param in group["params"]: |
|
|
|
bcast_param( |
|
|
|
param, |
|
|
|
"bcast_param_" + str(get_group_id()), |
|
|
|
is_root=(get_rank() == 0), |
|
|
|
param, "bcast_param_" + str(key), is_root=(get_rank() == 0), |
|
|
|
) |
|
|
|
key += 1 |
|
|
|
|
|
|
|
def state_dict(self) -> Dict: |
|
|
|
r"""Export the optimizer state. |
|
|
|