diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 528b805e..fefd2a68 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -28,7 +28,7 @@ class GradManager: self._call_back_dict[id(p)].append(cb) return self - def register_after_backward_callback(self, callback): + def _register_after_backward_callback(self, callback): self._after_backward_callback.append(callback) return self diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 693da1e8..aebc9b08 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -104,7 +104,7 @@ class AllreduceCallback: gm = get_backwarding_grad_manager() assert isinstance(gm, GradManager) if gm not in self._marked_gm: - gm.register_after_backward_callback(self._flush) + gm._register_after_backward_callback(self._flush) self._marked_gm.add(gm) self._params.append(param) self._futures_dict[param] = FakeTensor(ack=False)