Browse Source

fix(traced_module): fix __getattr__ of TracedModuleBuilder

GitOrigin-RevId: 94d91d6938
release-1.6
Megvii Engine Team 3 years ago
parent
commit
91264f3797
3 changed files with 23 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/traced_module/module_tracer.py
  2. +6
    -0
      imperative/python/megengine/traced_module/traced_module.py
  3. +16
    -0
      imperative/python/test/unit/traced_module/test_trace_module.py

+ 1
- 1
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -185,7 +185,6 @@ class PatchedFn:

class Patcher:

patched_fn_ids = set()
_builtin_functions = []
_builtin_modules = [
F,
@@ -207,6 +206,7 @@ class Patcher:
]

def __init__(self, wrap_fn):
self.patched_fn_ids = set()
self.patched_fn = []
self.visited_frames_ids = set()
self.wrap_fn = wrap_fn


+ 6
- 0
imperative/python/megengine/traced_module/traced_module.py View File

@@ -17,6 +17,7 @@ import re
import weakref
from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain
from types import FunctionType
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

from megengine import tensor
@@ -1150,6 +1151,11 @@ class TracedModuleBuilder(NodeMixin):
else:
attr = getattr(self._mod, name)
full_name = None
if (
isinstance(attr, FunctionType)
and id(attr) in active_module_tracer().patcher.patched_fn_ids
):
return active_module_tracer().patcher.wrap_fn(attr)

if id(attr) in active_module_tracer().id2name:
full_name = active_module_tracer().id2name[id(attr)]


+ 16
- 0
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -1,8 +1,10 @@
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction


class MyModule1(M.Module):
@@ -38,6 +40,15 @@ class MyModule3(M.Module):
return y


class MyModule4(M.Module):
def __init__(self):
super().__init__()
self.add = F.add

def forward(self, x, y):
return self.add(x, y)


def test_trace_module():

x = Tensor(1)
@@ -67,3 +78,8 @@ def test_trace_module():
assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
assert isinstance(tm3.modules.__dict__["2"], TracedModule)
assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)

m4 = MyModule4()
tm4 = trace_module(m4, a, b)
assert len(tm4.graph._exprs) == 1
assert isinstance(tm4.graph._exprs[0], CallFunction)

Loading…
Cancel
Save