|
|
@@ -104,7 +104,7 @@ class AllreduceCallback: |
|
|
|
gm = get_backwarding_grad_manager() |
|
|
|
assert isinstance(gm, GradManager) |
|
|
|
if gm not in self._marked_gm: |
|
|
|
gm.register_after_backward_callback(self._flush) |
|
|
|
gm._register_after_backward_callback(self._flush) |
|
|
|
self._marked_gm.add(gm) |
|
|
|
self._params.append(param) |
|
|
|
self._futures_dict[param] = FakeTensor(ack=False) |
|
|
|