|
|
@@ -88,7 +88,6 @@ class AllreduceCallback: |
|
|
|
self._futures_dict = dict() |
|
|
|
self._packing_list = defaultdict(list) |
|
|
|
self._packing_size = defaultdict(int) |
|
|
|
self._grad_origin_device = dict() |
|
|
|
|
|
|
|
def _pack(self, dtype): |
|
|
|
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] |
|
|
@@ -110,7 +109,6 @@ class AllreduceCallback: |
|
|
|
self._params.append(param) |
|
|
|
self._futures_dict[param] = TensorFuture(ack=False) |
|
|
|
self._gradients_dict[param] = grad |
|
|
|
self._grad_origin_device[param] = str(grad.device) |
|
|
|
|
|
|
|
dtype_str = str(np.dtype(param.dtype)) |
|
|
|
dtype_size = np.dtype(param.dtype).itemsize |
|
|
@@ -125,7 +123,6 @@ class AllreduceCallback: |
|
|
|
self._pack(dtype) |
|
|
|
for param in self._params: |
|
|
|
grad = self._gradients_dict[param] |
|
|
|
grad = copy(grad, self._grad_origin_device[param]) |
|
|
|
self._futures_dict[param].set(grad) |
|
|
|
self._reset() |
|
|
|
|
|
|
|