Browse Source

fix(mge/traced_module): fix some bugs for graph surgery

GitOrigin-RevId: 6328a84cbc
release-1.7
Megvii Engine Team 3 years ago
parent
commit
e6c271ae46
2 changed files with 139 additions and 77 deletions
  1. +112
    -75
      imperative/python/megengine/traced_module/traced_module.py
  2. +27
    -2
      imperative/python/test/unit/traced_module/test_modification.py

+ 112
- 75
imperative/python/megengine/traced_module/traced_module.py View File

@@ -122,18 +122,18 @@ def _is_leaf(node):
return isinstance(node, RawTensor) return isinstance(node, RawTensor)




_enable_node_to_tensor = False
_enable_graph_surgery_mode = False




def _convert_node_flag():
return _enable_node_to_tensor
def _graph_surgery_mode():
return _enable_graph_surgery_mode




def _set_convert_node_flag(flag: bool = False):
global _enable_node_to_tensor
pre_flag = _enable_node_to_tensor
_enable_node_to_tensor = flag
return pre_flag
def _set_graph_surgery_mode(mode: bool):
global _enable_graph_surgery_mode
pre_mode = _enable_graph_surgery_mode
_enable_graph_surgery_mode = mode
return pre_mode




def _node_to_tensor(*args, **kwargs): def _node_to_tensor(*args, **kwargs):
@@ -145,11 +145,11 @@ def _node_to_tensor(*args, **kwargs):
active_module_tracer().current_scope()._add_input(n) active_module_tracer().current_scope()._add_input(n)
value = n.value value = n.value
if value is None: if value is None:
flag = _set_convert_node_flag(False)
flag = _set_graph_surgery_mode(False)
unset_module_tracing() unset_module_tracing()
value = F.zeros(shape=n._shape, dtype=n._dtype) value = F.zeros(shape=n._shape, dtype=n._dtype)
set_module_tracing() set_module_tracing()
_set_convert_node_flag(flag)
_set_graph_surgery_mode(flag)
orig_n = NodeMixin.get(value, None) orig_n = NodeMixin.get(value, None)
if orig_n is None or "setitem" not in orig_n._name: if orig_n is None or "setitem" not in orig_n._name:
NodeMixin.wrap_safe(value, n) NodeMixin.wrap_safe(value, n)
@@ -180,17 +180,25 @@ def _tensor_to_node(tensors):




def _wrap_method_to_tensor_node(): def _wrap_method_to_tensor_node():
def _any_method(name):
def _any_method(name, func):
def _any(*args, **kwargs): def _any(*args, **kwargs):
args, kwargs = _node_to_tensor(*args, **kwargs)
attr = getattr(args[0], name)
outs = attr
if callable(attr):
outs = attr(*(args[1:]), **kwargs)
if name == "__setitem__":
_node_to_tensor(outs)
return None
outs = _tensor_to_node(outs)
if is_tracing_module() and _graph_surgery_mode():
args, kwargs = _node_to_tensor(*args, **kwargs)
attr = getattr(args[0], name)
outs = attr
if callable(attr):
outs = attr(*(args[1:]), **kwargs)
if name == "__setitem__":
_node_to_tensor(outs)
return None
outs = _tensor_to_node(outs)
return outs
else:
outs = func
if callable(func):
outs = func(*args, **kwargs)
if isinstance(func, property):
outs = func.__get__(*args, **kwargs)
return outs return outs


return _any return _any
@@ -199,9 +207,9 @@ def _wrap_method_to_tensor_node():
for method in get_tensor_wrapable_method(): for method in get_tensor_wrapable_method():
patch = PatchedFn(TensorNode, method) patch = PatchedFn(TensorNode, method)
if type(getattr(Tensor, method)) == property: if type(getattr(Tensor, method)) == property:
patch.set_func(property(_any_method(method)))
patch.set_func(property(_any_method(method, patch.origin_fn)))
else: else:
patch.set_func(_any_method(method))
patch.set_func(_any_method(method, patch.origin_fn))
tensor_method_patch.append(patch) tensor_method_patch.append(patch)
return tensor_method_patch return tensor_method_patch


@@ -209,7 +217,7 @@ def _wrap_method_to_tensor_node():
def _convert_node_and_tensor(orig_func): def _convert_node_and_tensor(orig_func):
@functools.wraps(orig_func) @functools.wraps(orig_func)
def _convert(*args, **kwargs): def _convert(*args, **kwargs):
if _convert_node_flag() and is_tracing_module():
if is_tracing_module() and _graph_surgery_mode():
args, kwargs = _node_to_tensor(*args, **kwargs) args, kwargs = _node_to_tensor(*args, **kwargs)
rst = orig_func(*args, **kwargs, method_func=_convert) rst = orig_func(*args, **kwargs, method_func=_convert)
rst = _tensor_to_node(rst) rst = _tensor_to_node(rst)
@@ -224,31 +232,35 @@ def _convert_node_and_tensor(orig_func):
def _wrap_mnode_getattr(orig_getattr): def _wrap_mnode_getattr(orig_getattr):
@functools.wraps(orig_getattr) @functools.wraps(orig_getattr)
def wraped_fn(self, name): def wraped_fn(self, name):
obj = self.owner
current_graph = active_module_tracer().current_scope()
if self.top_graph is not None:
current_graph._add_input(self)
attr = getattr(obj, name)
node = attr
if not isinstance(attr, TracedModuleBuilder):
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
setattr(obj, name, attr)

if is_tracing_module() and _graph_surgery_mode():
obj = self.owner
current_graph = active_module_tracer().current_scope()
if self.top_graph is not None:
current_graph._add_input(self)
attr = getattr(obj, name)
node = attr
if not isinstance(attr, TracedModuleBuilder):
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
setattr(obj, name, attr)

if isinstance(attr, (NodeMixin, RawTensor)):
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
self,
type=NodeMixin.get_wrapped_type(attr),
attr_name=name,
name="",
),
)
if isinstance(attr, (NodeMixin, RawTensor)): if isinstance(attr, (NodeMixin, RawTensor)):
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
self,
type=NodeMixin.get_wrapped_type(attr),
attr_name=name,
name="",
),
)
if isinstance(attr, (NodeMixin, RawTensor)):
node = NodeMixin.get(attr)
if isinstance(node, ModuleNode):
node._owner = weakref.ref(attr)
node = NodeMixin.get(attr)
if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)):
node._owner = weakref.ref(attr)
return node
else:
node = object.__getattribute__(self, name)
return node return node


return wraped_fn return wraped_fn
@@ -257,10 +269,13 @@ def _wrap_mnode_getattr(orig_getattr):
def _wrap_mnode_call(orig_call): def _wrap_mnode_call(orig_call):
@functools.wraps(orig_call) @functools.wraps(orig_call)
def wraped_fn(self, *args, **kwargs): def wraped_fn(self, *args, **kwargs):
obj = self.owner
if self.top_graph is not None:
active_module_tracer().current_scope()._add_input(self)
rst = obj(*args, **kwargs)
if is_tracing_module() and _graph_surgery_mode():
obj = self.owner
if self.top_graph is not None:
active_module_tracer().current_scope()._add_input(self)
rst = obj(*args, **kwargs)
else:
raise TypeError("'ModuleNode' object is not callable")
return rst return rst


return wraped_fn return wraped_fn
@@ -284,7 +299,7 @@ class _InsertExprs:
Node._set_next_id(node_id) Node._set_next_id(node_id)
Expr._set_next_id(expr_id) Expr._set_next_id(expr_id)
set_module_tracing() set_module_tracing()
_set_convert_node_flag(True)
_set_graph_surgery_mode(True)
assert active_module_tracer() is None assert active_module_tracer() is None
set_active_module_tracer( set_active_module_tracer(
module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x))) module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x)))
@@ -303,20 +318,30 @@ class _InsertExprs:
if va is not None: if va is not None:
return False return False
active_module_tracer().patcher.__exit__(ty, va, tr) active_module_tracer().patcher.__exit__(ty, va, tr)
_set_convert_node_flag(False)


while self._tensor_method_patch: while self._tensor_method_patch:
pf = self._tensor_method_patch.pop() pf = self._tensor_method_patch.pop()
pf.set_func(pf.origin_fn) pf.set_func(pf.origin_fn)


# delete ModuleNode.__call__ to avoid entering the
# ModuleNode.__init__ method when call a ModuleNode object.
delattr(ModuleNode, "__call__")

module = self.graph.inputs[0].owner module = self.graph.inputs[0].owner


for k, v in module.__dict__.items():
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(module, k, v)
def build_traced_module(
module: TracedModuleBuilder, target_module: TracedModule
):
for k, v in module.__dict__.items():
if isinstance(v, TracedModuleBuilder):
traced_v = v.build()
build_traced_module(v, traced_v)
setattr(target_module, k, traced_v)

build_traced_module(module, module)


set_symbolic_shape(self.use_sym_shape) set_symbolic_shape(self.use_sym_shape)
_set_graph_surgery_mode(False)
set_active_module_tracer(None) set_active_module_tracer(None)
unset_module_tracing() unset_module_tracing()


@@ -435,7 +460,7 @@ class NameSpace:


def unassociate_name_with_obj(self, node: Node): def unassociate_name_with_obj(self, node: Node):
assert node.name in self.used_names assert node.name in self.used_names
assert self.used_names[node.name] is node
# assert self.used_names[node.name] is node
self._used_names[node.name] = None self._used_names[node.name] = None


@property @property
@@ -1365,6 +1390,8 @@ class TracedModuleBuilder(NodeMixin):
node.module_type = mod_type node.module_type = mod_type


return self._mod return self._mod
elif isinstance(self._mod, TracedModule) and _graph_surgery_mode():
return self._mod
else: else:
is_qat = isinstance(self._mod, QATModule) or ( is_qat = isinstance(self._mod, QATModule) or (
isinstance(self._mod, TracedModule) and self._mod.is_qat isinstance(self._mod, TracedModule) and self._mod.is_qat
@@ -1409,6 +1436,10 @@ class TracedModuleBuilder(NodeMixin):


def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module) assert isinstance(self._mod, Module)
is_graph_surgery_mode = _graph_surgery_mode()
if isinstance(self._mod, TracedModule) and is_graph_surgery_mode:
_set_graph_surgery_mode(False)

# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
if "method_func" in kwargs: if "method_func" in kwargs:
kwargs.pop("method_func") kwargs.pop("method_func")
@@ -1514,7 +1545,7 @@ class TracedModuleBuilder(NodeMixin):
) )
rst = type(self._mod).forward(*args, **kwargs) rst = type(self._mod).forward(*args, **kwargs)


if _convert_node_flag():
if _graph_surgery_mode():
rst = _node_to_tensor(rst)[0][0] rst = _node_to_tensor(rst)[0][0]


outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
@@ -1536,6 +1567,7 @@ class TracedModuleBuilder(NodeMixin):
callnode.add_outputs(outputs) callnode.add_outputs(outputs)
self._argdef_graph_map[callnode.arg_def] = self._body self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def self._argdef_outdef_map[callnode.arg_def] = out_def
_set_graph_surgery_mode(is_graph_surgery_mode)
return rst return rst


def __setattr__(self, name, value): def __setattr__(self, name, value):
@@ -1556,7 +1588,7 @@ class TracedModuleBuilder(NodeMixin):
return active_module_tracer().patcher.wrap_fn(attr) return active_module_tracer().patcher.wrap_fn(attr)


if isinstance(attr, (List, Dict)): if isinstance(attr, (List, Dict)):
flag = _set_convert_node_flag(False)
flag = _set_graph_surgery_mode(False)
unset_module_tracing() unset_module_tracing()
has_module, m_container = replace_container_with_module_container(attr) has_module, m_container = replace_container_with_module_container(attr)
if m_container: if m_container:
@@ -1567,7 +1599,7 @@ class TracedModuleBuilder(NodeMixin):
" Module and Non-Module objects." " Module and Non-Module objects."
) )
set_module_tracing() set_module_tracing()
_set_convert_node_flag(flag)
_set_graph_surgery_mode(flag)


if isinstance(attr, Module): if isinstance(attr, Module):
attr = TracedModuleBuilder(attr) attr = TracedModuleBuilder(attr)
@@ -1628,20 +1660,25 @@ class _expr_iter:
self._visited_graph = set() self._visited_graph = set()


def __iter__(self): def __iter__(self):
for inp_node in self.graph.inputs:
yield inp_node.expr
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
if (
self.recursive
and expr.graph is not None
and id(expr.graph) not in self._visited_graph
):
yield from expr.graph.exprs(self.recursive)
self._visited_graph.add(id(expr.graph))
else:
yield expr
yield from self._gen_expr(self.graph)

def _gen_expr(self, graph: InternalGraph):
visit_inp = set()
for inp_node in graph.inputs:
if inp_node not in visit_inp:
yield inp_node.expr
visit_inp.add(inp_node)

for expr in graph._exprs:
yield expr
if (
self.recursive
and hasattr(expr, "graph")
and expr.graph is not None
and id(expr.graph) not in self._visited_graph
):
self._visited_graph.add(id(expr.graph))
yield from self._gen_expr(expr.graph)




class _node_iter: class _node_iter:


+ 27
- 2
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -15,7 +15,7 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.module.qat as qat import megengine.module.qat as qat
from megengine.module.identity import Identity from megengine.module.identity import Identity
from megengine.traced_module import trace_module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
from megengine.traced_module.node import ModuleNode, Node, TensorNode from megengine.traced_module.node import ModuleNode, Node, TensorNode


@@ -182,7 +182,6 @@ def test_insert_module():
setattr(traced_module, "neg", Neg(name="neg")) setattr(traced_module, "neg", Neg(name="neg"))
setattr(traced_module, "neg2", Neg(name="neg")) setattr(traced_module, "neg2", Neg(name="neg"))
setattr(traced_module, "param", F.zeros((1,))) setattr(traced_module, "param", F.zeros((1,)))

with graph.insert_exprs(): with graph.insert_exprs():
neg_out = self.neg(relu_out) neg_out = self.neg(relu_out)
neg_out = self.neg2(relu_out) neg_out = self.neg2(relu_out)
@@ -199,6 +198,32 @@ def test_insert_module():
if isinstance(n, TensorNode): if isinstance(n, TensorNode):
assert n.value is None assert n.value is None


traced_module, x, expect = _init_module()
setattr(traced_module.block0, "neg", Neg(name=None))
graph = traced_module.graph
self = graph.inputs[0]
out_node = graph.outputs[0]
with graph.insert_exprs():
neg_out = self.block0.neg(out_node)
graph.replace_node({out_node: neg_out})
graph.compile()
np.testing.assert_allclose(expect, -traced_module(x), atol=1e-6)
assert isinstance(traced_module.block0.neg, TracedModule)
assert traced_module.block0.neg.graph is not None

setattr(traced_module.block0.neg, "neg", Neg(name=None))
setattr(traced_module.block0.neg.neg, "relu", M.ReLU())
out_node = graph.outputs[0]
with graph.insert_exprs():
neg_out = self.block0.neg.neg(out_node)
neg_out = self.block0.neg.neg(neg_out)
relu_out = self.block0.neg.neg.relu(neg_out)
graph.replace_node({out_node: relu_out})
graph.compile()
np.testing.assert_allclose(F.relu(-expect), traced_module(x), atol=1e-6)
assert isinstance(traced_module.block0.neg.neg, TracedModule)
assert traced_module.block0.neg.neg.graph is not None



def test_insert_qat_module(): def test_insert_qat_module():
class concat(qat.Concat): class concat(qat.Concat):


Loading…
Cancel
Save