GitOrigin-RevId: 94d91d6938
release-1.6
@@ -185,7 +185,6 @@ class PatchedFn: | |||||
class Patcher: | class Patcher: | ||||
patched_fn_ids = set() | |||||
_builtin_functions = [] | _builtin_functions = [] | ||||
_builtin_modules = [ | _builtin_modules = [ | ||||
F, | F, | ||||
@@ -207,6 +206,7 @@ class Patcher: | |||||
] | ] | ||||
def __init__(self, wrap_fn): | def __init__(self, wrap_fn): | ||||
self.patched_fn_ids = set() | |||||
self.patched_fn = [] | self.patched_fn = [] | ||||
self.visited_frames_ids = set() | self.visited_frames_ids = set() | ||||
self.wrap_fn = wrap_fn | self.wrap_fn = wrap_fn | ||||
@@ -17,6 +17,7 @@ import re | |||||
import weakref | import weakref | ||||
from inspect import getcallargs, getmembers, isclass, ismethod | from inspect import getcallargs, getmembers, isclass, ismethod | ||||
from itertools import chain | from itertools import chain | ||||
from types import FunctionType | |||||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | ||||
from megengine import tensor | from megengine import tensor | ||||
@@ -1150,6 +1151,11 @@ class TracedModuleBuilder(NodeMixin): | |||||
else: | else: | ||||
attr = getattr(self._mod, name) | attr = getattr(self._mod, name) | ||||
full_name = None | full_name = None | ||||
if ( | |||||
isinstance(attr, FunctionType) | |||||
and id(attr) in active_module_tracer().patcher.patched_fn_ids | |||||
): | |||||
return active_module_tracer().patcher.wrap_fn(attr) | |||||
if id(attr) in active_module_tracer().id2name: | if id(attr) in active_module_tracer().id2name: | ||||
full_name = active_module_tracer().id2name[id(attr)] | full_name = active_module_tracer().id2name[id(attr)] | ||||
@@ -1,8 +1,10 @@ | |||||
import numpy as np | import numpy as np | ||||
import megengine.functional as F | |||||
import megengine.module as M | import megengine.module as M | ||||
from megengine import Tensor | from megengine import Tensor | ||||
from megengine.traced_module import TracedModule, trace_module | from megengine.traced_module import TracedModule, trace_module | ||||
from megengine.traced_module.expr import CallFunction | |||||
class MyModule1(M.Module): | class MyModule1(M.Module): | ||||
@@ -38,6 +40,15 @@ class MyModule3(M.Module): | |||||
return y | return y | ||||
class MyModule4(M.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.add = F.add | |||||
def forward(self, x, y): | |||||
return self.add(x, y) | |||||
def test_trace_module(): | def test_trace_module(): | ||||
x = Tensor(1) | x = Tensor(1) | ||||
@@ -67,3 +78,8 @@ def test_trace_module(): | |||||
assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) | assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) | ||||
assert isinstance(tm3.modules.__dict__["2"], TracedModule) | assert isinstance(tm3.modules.__dict__["2"], TracedModule) | ||||
assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) | assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) | ||||
m4 = MyModule4() | |||||
tm4 = trace_module(m4, a, b) | |||||
assert len(tm4.graph._exprs) == 1 | |||||
assert isinstance(tm4.graph._exprs[0], CallFunction) |