@@ -178,7 +178,7 @@ class Optimizer(metaclass=ABCMeta):
         assert len(grads) == len(params)
 
         for param, grad in zip(params, grads):
-            if is_distributed():
+            if is_distributed() and param.replica_mode:
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
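With this change, only parameters flagged with `param.replica_mode` have their gradients averaged across workers; gradients of non-replicated parameters stay local. Below is a minimal, framework-free sketch of that guarded averaging, assuming `replica_mode` marks a parameter that is replicated on every worker. `Param`, `all_reduce_mean`, and `reduce_grads` are illustrative stand-ins, not APIs of the library being patched.

# Minimal sketch of the guarded gradient averaging above (illustrative only).
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    name: str
    grad: float
    replica_mode: bool  # True if the parameter is replicated on every worker


def all_reduce_mean(values: List[float]) -> float:
    # Stand-in for a collective mean across workers,
    # i.e. all_reduce_sum(grad) / world_size in the diff.
    return sum(values) / len(values)


def reduce_grads(params: List[Param], grads_per_worker: List[List[float]]) -> None:
    # grads_per_worker[w][i] is worker w's local gradient for params[i].
    for i, param in enumerate(params):
        if param.replica_mode:
            # Replicated parameter: average its gradient across all workers.
            param.grad = all_reduce_mean([g[i] for g in grads_per_worker])
        else:
            # Non-replicated parameter: keep the local gradient untouched.
            param.grad = grads_per_worker[0][i]


params = [Param("w", 0.0, True), Param("local_stat", 0.0, False)]
reduce_grads(params, [[1.0, 5.0], [3.0, 7.0]])
print([p.grad for p in params])  # [2.0, 5.0]
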
@@ -230,10 +230,14 @@ class Optimizer(metaclass=ABCMeta):
         key = 0
         for group in self.param_groups:
             for param in group["params"]:
-                bcast_param(
-                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
-                )
-                key += 1
+                if param.replica_mode:
+                    bcast_param(
+                        param,
+                        "bcast_param_" + str(key),
+                        get_world_size(),
+                        get_rank() == 0,
+                    )
+                    key += 1
 
     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
|
|
|