Browse Source

fix(mge/traced_module): patch missed functional in functional.nn

GitOrigin-RevId: 5aa1316fb1
release-1.7
Megvii Engine Team 3 years ago
parent
commit
b8316de5a5
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      imperative/python/megengine/traced_module/module_tracer.py

+ 4
- 2
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -211,6 +211,8 @@ class Patcher:
self.wrap_fn = wrap_fn self.wrap_fn = wrap_fn
for module in self._builtin_modules: for module in self._builtin_modules:
self.patch_module(module) self.patch_module(module)
# some functions in F.nn are import from other module, and not in __all__
self.auto_patch(F.nn.__dict__, False)
for meth in BUILTIN_ARRAY_METHOD: for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
self.patch_method(Tensor, "detach", self.wrap_fn) self.patch_method(Tensor, "detach", self.wrap_fn)
@@ -256,8 +258,8 @@ class Patcher:
self.patch_function(module.__dict__, k, self.wrap_fn) self.patch_function(module.__dict__, k, self.wrap_fn)
self.visited_frames_ids.add(id(module.__dict__)) self.visited_frames_ids.add(id(module.__dict__))


def auto_patch(self, frame_dict):
if id(frame_dict) not in self.visited_frames_ids:
def auto_patch(self, frame_dict, check_frame_id=True):
if id(frame_dict) not in self.visited_frames_ids or not check_frame_id:
for k, v in frame_dict.items(): for k, v in frame_dict.items():
if id(v) in self.patched_fn_ids: if id(v) in self.patched_fn_ids:
self.patch_function(frame_dict, k, self.wrap_fn) self.patch_function(frame_dict, k, self.wrap_fn)


Loading…
Cancel
Save