Browse Source

feat(mge/distributed): add parameter replica_mode

GitOrigin-RevId: 244e4ca437
release-0.6
Megvii Engine Team 5 years ago
parent
commit
2b3a88d011
2 changed files with 12 additions and 5 deletions
  1. +3
    -0
      python_module/megengine/core/tensor_nn.py
  2. +9
    -5
      python_module/megengine/optimizer/optimizer.py

+ 3
- 0
python_module/megengine/core/tensor_nn.py View File

@@ -31,6 +31,9 @@ class Parameter(Tensor):
t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
self.__dict__.update(t.__dict__)

# when replica_mode is False, the optimizer skips broadcast and all-reduce for this parameter
self.replica_mode = True

@property
def shape(self):
r"""Return shape of parameter.


+ 9
- 5
python_module/megengine/optimizer/optimizer.py View File

@@ -178,7 +178,7 @@ class Optimizer(metaclass=ABCMeta):
assert len(grads) == len(params)

for param, grad in zip(params, grads):
if is_distributed():
if is_distributed() and param.replica_mode:
with opr_priority_scope(cg, -(2 ** 30)):
# always run all_reduce_mean first except add_update
grad = (
@@ -230,10 +230,14 @@ class Optimizer(metaclass=ABCMeta):
key = 0
for group in self.param_groups:
for param in group["params"]:
bcast_param(
param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
)
key += 1
if param.replica_mode:
bcast_param(
param,
"bcast_param_" + str(key),
get_world_size(),
get_rank() == 0,
)
key += 1

def state_dict(self) -> Dict:
r"""Export the optimizer state.


Loading…
Cancel
Save