|
|
@@ -218,7 +218,7 @@ class AllreduceCallback: |
|
|
|
if len(self._packing_list[dtype]) == 0: |
|
|
|
return |
|
|
|
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] |
|
|
|
shapes = [p.shape for p in self._packing_list[dtype]] |
|
|
|
shapes = [p._tuple_shape for p in self._packing_list[dtype]] |
|
|
|
reduced_grads = pack_allreduce_split( |
|
|
|
grad_list, shapes, self._group, self._reduce_method |
|
|
|
) |
|
|
@@ -241,7 +241,7 @@ class AllreduceCallback: |
|
|
|
dtype_str = str(np.dtype(param.dtype)) |
|
|
|
dtype_size = np.dtype(param.dtype).itemsize |
|
|
|
self._packing_list[dtype_str].append(param) |
|
|
|
self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size |
|
|
|
self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size |
|
|
|
if self._packing_size[dtype_str] > self._param_pack_thd: |
|
|
|
self._pack(dtype_str) |
|
|
|
return self._futures_dict[param] |
|
|
|