|
|
@@ -33,7 +33,7 @@ class Node: |
|
|
|
_orig_name = None # type: str |
|
|
|
_format_spec = "" # type: str |
|
|
|
|
|
|
|
def __init__(self, expr: "Expr", name: str, orig_name: str): |
|
|
|
def __init__(self, expr, name: str, orig_name: str): |
|
|
|
self.expr = expr |
|
|
|
self.users = [] # List[Expr] |
|
|
|
self._id = Node.__total_id |
|
|
@@ -120,7 +120,7 @@ class ModuleNode(Node): |
|
|
|
r"""The type of the Module correspending to the ModuleNode.""" |
|
|
|
_owner = None # type: weakref.ReferenceType |
|
|
|
|
|
|
|
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): |
|
|
|
def __init__(self, expr, name: str = None, orig_name: str = None): |
|
|
|
super().__init__(expr, name, orig_name) |
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
@@ -136,9 +136,6 @@ class ModuleNode(Node): |
|
|
|
@property |
|
|
|
def owner(self): |
|
|
|
r"""Get the ``Module`` corresponding to this ``ModuleNode``. |
|
|
|
|
|
|
|
Returns: |
|
|
|
An :calss:`~.Module`. |
|
|
|
""" |
|
|
|
if self._owner: |
|
|
|
return self._owner() |
|
|
@@ -196,7 +193,7 @@ class TensorNode(Node): |
|
|
|
|
|
|
|
@property |
|
|
|
def qparams(self): |
|
|
|
r"""Get the :calss:`QParams` of this Node.""" |
|
|
|
r"""Get the :class:`QParams` of this Node.""" |
|
|
|
return self._qparams |
|
|
|
|
|
|
|
@qparams.setter |
|
|
@@ -210,11 +207,7 @@ class TensorNode(Node): |
|
|
|
|
|
|
|
@value.setter |
|
|
|
def value(self, value): |
|
|
|
r"""Bind a Tensor to this Node. |
|
|
|
|
|
|
|
Args: |
|
|
|
value: A :class:`Tensor`. |
|
|
|
""" |
|
|
|
r"""Bind a :class:`Tensor` to this Node.""" |
|
|
|
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: |
|
|
|
setattr(value, "_NodeMixin__node", None) |
|
|
|
self._value = value |
|
|
|