|
|
@@ -122,18 +122,18 @@ def _is_leaf(node): |
|
|
|
return isinstance(node, RawTensor) |
|
|
|
|
|
|
|
|
|
|
|
_enable_node_to_tensor = False |
|
|
|
_enable_graph_surgery_mode = False |
|
|
|
|
|
|
|
|
|
|
|
def _convert_node_flag(): |
|
|
|
return _enable_node_to_tensor |
|
|
|
def _graph_surgery_mode(): |
|
|
|
return _enable_graph_surgery_mode |
|
|
|
|
|
|
|
|
|
|
|
def _set_convert_node_flag(flag: bool = False): |
|
|
|
global _enable_node_to_tensor |
|
|
|
pre_flag = _enable_node_to_tensor |
|
|
|
_enable_node_to_tensor = flag |
|
|
|
return pre_flag |
|
|
|
def _set_graph_surgery_mode(mode: bool): |
|
|
|
global _enable_graph_surgery_mode |
|
|
|
pre_mode = _enable_graph_surgery_mode |
|
|
|
_enable_graph_surgery_mode = mode |
|
|
|
return pre_mode |
|
|
|
|
|
|
|
|
|
|
|
def _node_to_tensor(*args, **kwargs): |
|
|
@@ -145,11 +145,11 @@ def _node_to_tensor(*args, **kwargs): |
|
|
|
active_module_tracer().current_scope()._add_input(n) |
|
|
|
value = n.value |
|
|
|
if value is None: |
|
|
|
flag = _set_convert_node_flag(False) |
|
|
|
flag = _set_graph_surgery_mode(False) |
|
|
|
unset_module_tracing() |
|
|
|
value = F.zeros(shape=n._shape, dtype=n._dtype) |
|
|
|
set_module_tracing() |
|
|
|
_set_convert_node_flag(flag) |
|
|
|
_set_graph_surgery_mode(flag) |
|
|
|
orig_n = NodeMixin.get(value, None) |
|
|
|
if orig_n is None or "setitem" not in orig_n._name: |
|
|
|
NodeMixin.wrap_safe(value, n) |
|
|
@@ -180,17 +180,25 @@ def _tensor_to_node(tensors): |
|
|
|
|
|
|
|
|
|
|
|
def _wrap_method_to_tensor_node(): |
|
|
|
def _any_method(name): |
|
|
|
def _any_method(name, func): |
|
|
|
def _any(*args, **kwargs): |
|
|
|
args, kwargs = _node_to_tensor(*args, **kwargs) |
|
|
|
attr = getattr(args[0], name) |
|
|
|
outs = attr |
|
|
|
if callable(attr): |
|
|
|
outs = attr(*(args[1:]), **kwargs) |
|
|
|
if name == "__setitem__": |
|
|
|
_node_to_tensor(outs) |
|
|
|
return None |
|
|
|
outs = _tensor_to_node(outs) |
|
|
|
if is_tracing_module() and _graph_surgery_mode(): |
|
|
|
args, kwargs = _node_to_tensor(*args, **kwargs) |
|
|
|
attr = getattr(args[0], name) |
|
|
|
outs = attr |
|
|
|
if callable(attr): |
|
|
|
outs = attr(*(args[1:]), **kwargs) |
|
|
|
if name == "__setitem__": |
|
|
|
_node_to_tensor(outs) |
|
|
|
return None |
|
|
|
outs = _tensor_to_node(outs) |
|
|
|
return outs |
|
|
|
else: |
|
|
|
outs = func |
|
|
|
if callable(func): |
|
|
|
outs = func(*args, **kwargs) |
|
|
|
if isinstance(func, property): |
|
|
|
outs = func.__get__(*args, **kwargs) |
|
|
|
return outs |
|
|
|
|
|
|
|
return _any |
|
|
@@ -199,9 +207,9 @@ def _wrap_method_to_tensor_node(): |
|
|
|
for method in get_tensor_wrapable_method(): |
|
|
|
patch = PatchedFn(TensorNode, method) |
|
|
|
if type(getattr(Tensor, method)) == property: |
|
|
|
patch.set_func(property(_any_method(method))) |
|
|
|
patch.set_func(property(_any_method(method, patch.origin_fn))) |
|
|
|
else: |
|
|
|
patch.set_func(_any_method(method)) |
|
|
|
patch.set_func(_any_method(method, patch.origin_fn)) |
|
|
|
tensor_method_patch.append(patch) |
|
|
|
return tensor_method_patch |
|
|
|
|
|
|
@@ -209,7 +217,7 @@ def _wrap_method_to_tensor_node(): |
|
|
|
def _convert_node_and_tensor(orig_func): |
|
|
|
@functools.wraps(orig_func) |
|
|
|
def _convert(*args, **kwargs): |
|
|
|
if _convert_node_flag() and is_tracing_module(): |
|
|
|
if is_tracing_module() and _graph_surgery_mode(): |
|
|
|
args, kwargs = _node_to_tensor(*args, **kwargs) |
|
|
|
rst = orig_func(*args, **kwargs, method_func=_convert) |
|
|
|
rst = _tensor_to_node(rst) |
|
|
@@ -224,31 +232,35 @@ def _convert_node_and_tensor(orig_func): |
|
|
|
def _wrap_mnode_getattr(orig_getattr): |
|
|
|
@functools.wraps(orig_getattr) |
|
|
|
def wraped_fn(self, name): |
|
|
|
obj = self.owner |
|
|
|
current_graph = active_module_tracer().current_scope() |
|
|
|
if self.top_graph is not None: |
|
|
|
current_graph._add_input(self) |
|
|
|
attr = getattr(obj, name) |
|
|
|
node = attr |
|
|
|
if not isinstance(attr, TracedModuleBuilder): |
|
|
|
if isinstance(attr, Module): |
|
|
|
attr = TracedModuleBuilder(attr) |
|
|
|
setattr(obj, name, attr) |
|
|
|
|
|
|
|
if is_tracing_module() and _graph_surgery_mode(): |
|
|
|
obj = self.owner |
|
|
|
current_graph = active_module_tracer().current_scope() |
|
|
|
if self.top_graph is not None: |
|
|
|
current_graph._add_input(self) |
|
|
|
attr = getattr(obj, name) |
|
|
|
node = attr |
|
|
|
if not isinstance(attr, TracedModuleBuilder): |
|
|
|
if isinstance(attr, Module): |
|
|
|
attr = TracedModuleBuilder(attr) |
|
|
|
setattr(obj, name, attr) |
|
|
|
|
|
|
|
if isinstance(attr, (NodeMixin, RawTensor)): |
|
|
|
NodeMixin.wrap( |
|
|
|
attr, |
|
|
|
lambda: GetAttr.make( |
|
|
|
self, |
|
|
|
type=NodeMixin.get_wrapped_type(attr), |
|
|
|
attr_name=name, |
|
|
|
name="", |
|
|
|
), |
|
|
|
) |
|
|
|
if isinstance(attr, (NodeMixin, RawTensor)): |
|
|
|
NodeMixin.wrap( |
|
|
|
attr, |
|
|
|
lambda: GetAttr.make( |
|
|
|
self, |
|
|
|
type=NodeMixin.get_wrapped_type(attr), |
|
|
|
attr_name=name, |
|
|
|
name="", |
|
|
|
), |
|
|
|
) |
|
|
|
if isinstance(attr, (NodeMixin, RawTensor)): |
|
|
|
node = NodeMixin.get(attr) |
|
|
|
if isinstance(node, ModuleNode): |
|
|
|
node._owner = weakref.ref(attr) |
|
|
|
node = NodeMixin.get(attr) |
|
|
|
if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)): |
|
|
|
node._owner = weakref.ref(attr) |
|
|
|
return node |
|
|
|
else: |
|
|
|
node = object.__getattribute__(self, name) |
|
|
|
return node |
|
|
|
|
|
|
|
return wraped_fn |
|
|
@@ -257,10 +269,13 @@ def _wrap_mnode_getattr(orig_getattr): |
|
|
|
def _wrap_mnode_call(orig_call): |
|
|
|
@functools.wraps(orig_call) |
|
|
|
def wraped_fn(self, *args, **kwargs): |
|
|
|
obj = self.owner |
|
|
|
if self.top_graph is not None: |
|
|
|
active_module_tracer().current_scope()._add_input(self) |
|
|
|
rst = obj(*args, **kwargs) |
|
|
|
if is_tracing_module() and _graph_surgery_mode(): |
|
|
|
obj = self.owner |
|
|
|
if self.top_graph is not None: |
|
|
|
active_module_tracer().current_scope()._add_input(self) |
|
|
|
rst = obj(*args, **kwargs) |
|
|
|
else: |
|
|
|
raise TypeError("'ModuleNode' object is not callable") |
|
|
|
return rst |
|
|
|
|
|
|
|
return wraped_fn |
|
|
@@ -284,7 +299,7 @@ class _InsertExprs: |
|
|
|
Node._set_next_id(node_id) |
|
|
|
Expr._set_next_id(expr_id) |
|
|
|
set_module_tracing() |
|
|
|
_set_convert_node_flag(True) |
|
|
|
_set_graph_surgery_mode(True) |
|
|
|
assert active_module_tracer() is None |
|
|
|
set_active_module_tracer( |
|
|
|
module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x))) |
|
|
@@ -303,20 +318,30 @@ class _InsertExprs: |
|
|
|
if va is not None: |
|
|
|
return False |
|
|
|
active_module_tracer().patcher.__exit__(ty, va, tr) |
|
|
|
_set_convert_node_flag(False) |
|
|
|
|
|
|
|
while self._tensor_method_patch: |
|
|
|
pf = self._tensor_method_patch.pop() |
|
|
|
pf.set_func(pf.origin_fn) |
|
|
|
|
|
|
|
# delete ModuleNode.__call__ to avoid entering the |
|
|
|
# ModuleNode.__init__ method when call a ModuleNode object. |
|
|
|
delattr(ModuleNode, "__call__") |
|
|
|
|
|
|
|
module = self.graph.inputs[0].owner |
|
|
|
|
|
|
|
for k, v in module.__dict__.items(): |
|
|
|
if isinstance(v, TracedModuleBuilder): |
|
|
|
v = v.build() |
|
|
|
setattr(module, k, v) |
|
|
|
def build_traced_module( |
|
|
|
module: TracedModuleBuilder, target_module: TracedModule |
|
|
|
): |
|
|
|
for k, v in module.__dict__.items(): |
|
|
|
if isinstance(v, TracedModuleBuilder): |
|
|
|
traced_v = v.build() |
|
|
|
build_traced_module(v, traced_v) |
|
|
|
setattr(target_module, k, traced_v) |
|
|
|
|
|
|
|
build_traced_module(module, module) |
|
|
|
|
|
|
|
set_symbolic_shape(self.use_sym_shape) |
|
|
|
_set_graph_surgery_mode(False) |
|
|
|
set_active_module_tracer(None) |
|
|
|
unset_module_tracing() |
|
|
|
|
|
|
@@ -435,7 +460,7 @@ class NameSpace: |
|
|
|
|
|
|
|
def unassociate_name_with_obj(self, node: Node): |
|
|
|
assert node.name in self.used_names |
|
|
|
assert self.used_names[node.name] is node |
|
|
|
# assert self.used_names[node.name] is node |
|
|
|
self._used_names[node.name] = None |
|
|
|
|
|
|
|
@property |
|
|
@@ -1365,6 +1390,8 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
node.module_type = mod_type |
|
|
|
|
|
|
|
return self._mod |
|
|
|
elif isinstance(self._mod, TracedModule) and _graph_surgery_mode(): |
|
|
|
return self._mod |
|
|
|
else: |
|
|
|
is_qat = isinstance(self._mod, QATModule) or ( |
|
|
|
isinstance(self._mod, TracedModule) and self._mod.is_qat |
|
|
@@ -1409,6 +1436,10 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
assert isinstance(self._mod, Module) |
|
|
|
is_graph_surgery_mode = _graph_surgery_mode() |
|
|
|
if isinstance(self._mod, TracedModule) and is_graph_surgery_mode: |
|
|
|
_set_graph_surgery_mode(False) |
|
|
|
|
|
|
|
# prepare args and kwargs for inner graph |
|
|
|
if "method_func" in kwargs: |
|
|
|
kwargs.pop("method_func") |
|
|
@@ -1514,7 +1545,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
) |
|
|
|
rst = type(self._mod).forward(*args, **kwargs) |
|
|
|
|
|
|
|
if _convert_node_flag(): |
|
|
|
if _graph_surgery_mode(): |
|
|
|
rst = _node_to_tensor(rst)[0][0] |
|
|
|
|
|
|
|
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) |
|
|
@@ -1536,6 +1567,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
callnode.add_outputs(outputs) |
|
|
|
self._argdef_graph_map[callnode.arg_def] = self._body |
|
|
|
self._argdef_outdef_map[callnode.arg_def] = out_def |
|
|
|
_set_graph_surgery_mode(is_graph_surgery_mode) |
|
|
|
return rst |
|
|
|
|
|
|
|
def __setattr__(self, name, value): |
|
|
@@ -1556,7 +1588,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
return active_module_tracer().patcher.wrap_fn(attr) |
|
|
|
|
|
|
|
if isinstance(attr, (List, Dict)): |
|
|
|
flag = _set_convert_node_flag(False) |
|
|
|
flag = _set_graph_surgery_mode(False) |
|
|
|
unset_module_tracing() |
|
|
|
has_module, m_container = replace_container_with_module_container(attr) |
|
|
|
if m_container: |
|
|
@@ -1567,7 +1599,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
" Module and Non-Module objects." |
|
|
|
) |
|
|
|
set_module_tracing() |
|
|
|
_set_convert_node_flag(flag) |
|
|
|
_set_graph_surgery_mode(flag) |
|
|
|
|
|
|
|
if isinstance(attr, Module): |
|
|
|
attr = TracedModuleBuilder(attr) |
|
|
@@ -1628,20 +1660,25 @@ class _expr_iter: |
|
|
|
self._visited_graph = set() |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
for inp_node in self.graph.inputs: |
|
|
|
yield inp_node.expr |
|
|
|
for expr in self.graph._exprs: |
|
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
|
yield expr |
|
|
|
if ( |
|
|
|
self.recursive |
|
|
|
and expr.graph is not None |
|
|
|
and id(expr.graph) not in self._visited_graph |
|
|
|
): |
|
|
|
yield from expr.graph.exprs(self.recursive) |
|
|
|
self._visited_graph.add(id(expr.graph)) |
|
|
|
else: |
|
|
|
yield expr |
|
|
|
yield from self._gen_expr(self.graph) |
|
|
|
|
|
|
|
def _gen_expr(self, graph: InternalGraph): |
|
|
|
visit_inp = set() |
|
|
|
for inp_node in graph.inputs: |
|
|
|
if inp_node not in visit_inp: |
|
|
|
yield inp_node.expr |
|
|
|
visit_inp.add(inp_node) |
|
|
|
|
|
|
|
for expr in graph._exprs: |
|
|
|
yield expr |
|
|
|
if ( |
|
|
|
self.recursive |
|
|
|
and hasattr(expr, "graph") |
|
|
|
and expr.graph is not None |
|
|
|
and id(expr.graph) not in self._visited_graph |
|
|
|
): |
|
|
|
self._visited_graph.add(id(expr.graph)) |
|
|
|
yield from self._gen_expr(expr.graph) |
|
|
|
|
|
|
|
|
|
|
|
class _node_iter: |
|
|
|