GitOrigin-RevId: 5d2d047d2f
release-1.7
@@ -159,13 +159,13 @@ class Expr: | |||||
@property | @property | ||||
def kwargs(self): | def kwargs(self): | ||||
r"""Get the the keyword arguments of the operation corresponding to this Expr.""" | |||||
r"""Get the keyword arguments of the operation corresponding to this Expr.""" | |||||
_, kwargs = self.unflatten_args(self.inputs) | _, kwargs = self.unflatten_args(self.inputs) | ||||
return kwargs | return kwargs | ||||
@property | @property | ||||
def args(self): | def args(self): | ||||
r"""Get the the positional arguments of the operation corresponding to this Expr.""" | |||||
r"""Get the positional arguments of the operation corresponding to this Expr.""" | |||||
args, _ = self.unflatten_args(self.inputs) | args, _ = self.unflatten_args(self.inputs) | ||||
return args | return args | ||||
@@ -33,7 +33,7 @@ class Node: | |||||
_orig_name = None # type: str | _orig_name = None # type: str | ||||
_format_spec = "" # type: str | _format_spec = "" # type: str | ||||
def __init__(self, expr: "Expr", name: str, orig_name: str): | |||||
def __init__(self, expr, name: str, orig_name: str): | |||||
self.expr = expr | self.expr = expr | ||||
self.users = [] # List[Expr] | self.users = [] # List[Expr] | ||||
self._id = Node.__total_id | self._id = Node.__total_id | ||||
@@ -120,7 +120,7 @@ class ModuleNode(Node): | |||||
r"""The type of the Module correspending to the ModuleNode.""" | r"""The type of the Module correspending to the ModuleNode.""" | ||||
_owner = None # type: weakref.ReferenceType | _owner = None # type: weakref.ReferenceType | ||||
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): | |||||
def __init__(self, expr, name: str = None, orig_name: str = None): | |||||
super().__init__(expr, name, orig_name) | super().__init__(expr, name, orig_name) | ||||
def __getstate__(self): | def __getstate__(self): | ||||
@@ -136,9 +136,6 @@ class ModuleNode(Node): | |||||
@property | @property | ||||
def owner(self): | def owner(self): | ||||
r"""Get the ``Module`` corresponding to this ``ModuleNode``. | r"""Get the ``Module`` corresponding to this ``ModuleNode``. | ||||
Returns: | |||||
An :calss:`~.Module`. | |||||
""" | """ | ||||
if self._owner: | if self._owner: | ||||
return self._owner() | return self._owner() | ||||
@@ -196,7 +193,7 @@ class TensorNode(Node): | |||||
@property | @property | ||||
def qparams(self): | def qparams(self): | ||||
r"""Get the :calss:`QParams` of this Node.""" | |||||
r"""Get the :class:`QParams` of this Node.""" | |||||
return self._qparams | return self._qparams | ||||
@qparams.setter | @qparams.setter | ||||
@@ -210,11 +207,7 @@ class TensorNode(Node): | |||||
@value.setter | @value.setter | ||||
def value(self, value): | def value(self, value): | ||||
r"""Bind a Tensor to this Node. | |||||
Args: | |||||
value: A :class:`Tensor`. | |||||
""" | |||||
r"""Bind a :class:`Tensor` to this Node.""" | |||||
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: | if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: | ||||
setattr(value, "_NodeMixin__node", None) | setattr(value, "_NodeMixin__node", None) | ||||
self._value = value | self._value = value | ||||
@@ -150,8 +150,8 @@ def tree_flatten( | |||||
is_leaf: Callable = _is_leaf, | is_leaf: Callable = _is_leaf, | ||||
is_const_leaf: Callable = _is_const_leaf, | is_const_leaf: Callable = _is_const_leaf, | ||||
): | ): | ||||
r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used | |||||
to reconstruct the object. | |||||
r"""Flattens a pytree into a list of values and a :class:`TreeDef` that can be used | |||||
to reconstruct the pytree. | |||||
""" | """ | ||||
if type(values) not in SUPPORTED_TYPE: | if type(values) not in SUPPORTED_TYPE: | ||||
assert is_leaf(values), values | assert is_leaf(values), values | ||||
@@ -188,7 +188,7 @@ class TreeDef: | |||||
self.num_leaves = sum(ch.num_leaves for ch in children_defs) | self.num_leaves = sum(ch.num_leaves for ch in children_defs) | ||||
def unflatten(self, leaves): | def unflatten(self, leaves): | ||||
r"""Given a list of values and a ``TreeDef``, builds a object. | |||||
r"""Given a list of values and a ``TreeDef``, builds a pytree. | |||||
This is the inverse operation of ``tree_flatten``. | This is the inverse operation of ``tree_flatten``. | ||||
""" | """ | ||||
assert len(leaves) == self.num_leaves | assert len(leaves) == self.num_leaves | ||||
@@ -453,7 +453,7 @@ class InternalGraph: | |||||
r"""Get the list of output Nodes of this graph. | r"""Get the list of output Nodes of this graph. | ||||
Returns: | Returns: | ||||
A list of Node. | |||||
A list of ``Node``. | |||||
""" | """ | ||||
return self._outputs | return self._outputs | ||||
@@ -1937,7 +1937,7 @@ class TracedModule(Module): | |||||
@property | @property | ||||
def graph(self) -> InternalGraph: | def graph(self) -> InternalGraph: | ||||
"""Return the ``InternalGraph`` of this ``TracedModule`` | |||||
"""Return the ``InternalGraph`` of this ``TracedModule``. | |||||
""" | """ | ||||
if self._is_top: | if self._is_top: | ||||
self._update_ref() | self._update_ref() | ||||