|
|
@@ -211,6 +211,8 @@ class AllreduceCallback: |
|
|
|
self._grad_origin_device = dict() |
|
|
|
|
|
|
|
def _pack(self, dtype): |
|
|
|
if len(self._packing_list[dtype]) == 0: |
|
|
|
return |
|
|
|
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] |
|
|
|
shapes = [p.shape for p in self._packing_list[dtype]] |
|
|
|
reduced_grads = pack_allreduce_split( |
|
|
|