Browse Source

fix(traced_module/doc): fix warning for traced_module docstring

GitOrigin-RevId: 5d2d047d2f
release-1.7
Megvii Engine Team 3 years ago
parent
commit
300955140e
4 changed files with 11 additions and 18 deletions
  1. +2
    -2
      imperative/python/megengine/traced_module/expr.py
  2. +4
    -11
      imperative/python/megengine/traced_module/node.py
  3. +3
    -3
      imperative/python/megengine/traced_module/pytree.py
  4. +2
    -2
      imperative/python/megengine/traced_module/traced_module.py

+ 2
- 2
imperative/python/megengine/traced_module/expr.py View File

@@ -159,13 +159,13 @@ class Expr:


@property @property
def kwargs(self): def kwargs(self):
r"""Get the the keyword arguments of the operation corresponding to this Expr."""
r"""Get the keyword arguments of the operation corresponding to this Expr."""
_, kwargs = self.unflatten_args(self.inputs) _, kwargs = self.unflatten_args(self.inputs)
return kwargs return kwargs


@property @property
def args(self): def args(self):
r"""Get the the positional arguments of the operation corresponding to this Expr."""
r"""Get the positional arguments of the operation corresponding to this Expr."""
args, _ = self.unflatten_args(self.inputs) args, _ = self.unflatten_args(self.inputs)
return args return args




+ 4
- 11
imperative/python/megengine/traced_module/node.py View File

@@ -33,7 +33,7 @@ class Node:
_orig_name = None # type: str _orig_name = None # type: str
_format_spec = "" # type: str _format_spec = "" # type: str


def __init__(self, expr: "Expr", name: str, orig_name: str):
def __init__(self, expr, name: str, orig_name: str):
self.expr = expr self.expr = expr
self.users = [] # List[Expr] self.users = [] # List[Expr]
self._id = Node.__total_id self._id = Node.__total_id
@@ -120,7 +120,7 @@ class ModuleNode(Node):
r"""The type of the Module correspending to the ModuleNode.""" r"""The type of the Module correspending to the ModuleNode."""
_owner = None # type: weakref.ReferenceType _owner = None # type: weakref.ReferenceType


def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
def __init__(self, expr, name: str = None, orig_name: str = None):
super().__init__(expr, name, orig_name) super().__init__(expr, name, orig_name)


def __getstate__(self): def __getstate__(self):
@@ -136,9 +136,6 @@ class ModuleNode(Node):
@property @property
def owner(self): def owner(self):
r"""Get the ``Module`` corresponding to this ``ModuleNode``. r"""Get the ``Module`` corresponding to this ``ModuleNode``.

Returns:
An :calss:`~.Module`.
""" """
if self._owner: if self._owner:
return self._owner() return self._owner()
@@ -196,7 +193,7 @@ class TensorNode(Node):


@property @property
def qparams(self): def qparams(self):
r"""Get the :calss:`QParams` of this Node."""
r"""Get the :class:`QParams` of this Node."""
return self._qparams return self._qparams


@qparams.setter @qparams.setter
@@ -210,11 +207,7 @@ class TensorNode(Node):


@value.setter @value.setter
def value(self, value): def value(self, value):
r"""Bind a Tensor to this Node.

Args:
value: A :class:`Tensor`.
"""
r"""Bind a :class:`Tensor` to this Node."""
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None) setattr(value, "_NodeMixin__node", None)
self._value = value self._value = value


+ 3
- 3
imperative/python/megengine/traced_module/pytree.py View File

@@ -150,8 +150,8 @@ def tree_flatten(
is_leaf: Callable = _is_leaf, is_leaf: Callable = _is_leaf,
is_const_leaf: Callable = _is_const_leaf, is_const_leaf: Callable = _is_const_leaf,
): ):
r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used
to reconstruct the object.
r"""Flattens a pytree into a list of values and a :class:`TreeDef` that can be used
to reconstruct the pytree.
""" """
if type(values) not in SUPPORTED_TYPE: if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values), values assert is_leaf(values), values
@@ -188,7 +188,7 @@ class TreeDef:
self.num_leaves = sum(ch.num_leaves for ch in children_defs) self.num_leaves = sum(ch.num_leaves for ch in children_defs)


def unflatten(self, leaves): def unflatten(self, leaves):
r"""Given a list of values and a ``TreeDef``, builds a object.
r"""Given a list of values and a ``TreeDef``, builds a pytree.
This is the inverse operation of ``tree_flatten``. This is the inverse operation of ``tree_flatten``.
""" """
assert len(leaves) == self.num_leaves assert len(leaves) == self.num_leaves


+ 2
- 2
imperative/python/megengine/traced_module/traced_module.py View File

@@ -453,7 +453,7 @@ class InternalGraph:
r"""Get the list of output Nodes of this graph. r"""Get the list of output Nodes of this graph.


Returns: Returns:
A list of Node.
A list of ``Node``.
""" """
return self._outputs return self._outputs


@@ -1937,7 +1937,7 @@ class TracedModule(Module):


@property @property
def graph(self) -> InternalGraph: def graph(self) -> InternalGraph:
"""Return the ``InternalGraph`` of this ``TracedModule``
"""Return the ``InternalGraph`` of this ``TracedModule``.
""" """
if self._is_top: if self._is_top:
self._update_ref() self._update_ref()


Loading…
Cancel
Save