|
|
@@ -211,6 +211,8 @@ class Patcher: |
|
|
|
self.wrap_fn = wrap_fn |
|
|
|
for module in self._builtin_modules: |
|
|
|
self.patch_module(module) |
|
|
|
# some functions in F.nn are import from other module, and not in __all__ |
|
|
|
self.auto_patch(F.nn.__dict__, False) |
|
|
|
for meth in BUILTIN_ARRAY_METHOD: |
|
|
|
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) |
|
|
|
self.patch_method(Tensor, "detach", self.wrap_fn) |
|
|
@@ -256,8 +258,8 @@ class Patcher: |
|
|
|
self.patch_function(module.__dict__, k, self.wrap_fn) |
|
|
|
self.visited_frames_ids.add(id(module.__dict__)) |
|
|
|
|
|
|
|
def auto_patch(self, frame_dict): |
|
|
|
if id(frame_dict) not in self.visited_frames_ids: |
|
|
|
def auto_patch(self, frame_dict, check_frame_id=True): |
|
|
|
if id(frame_dict) not in self.visited_frames_ids or not check_frame_id: |
|
|
|
for k, v in frame_dict.items(): |
|
|
|
if id(v) in self.patched_fn_ids: |
|
|
|
self.patch_function(frame_dict, k, self.wrap_fn) |
|
|
|