Browse Source

docs(mge): add missing docstring and fix sphinx build warnings

GitOrigin-RevId: 4ce73cfd80
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
4f3875eb4f
34 changed files with 362 additions and 119 deletions
  1. +1
    -0
      imperative/python/megengine/core/autodiff/grad.py
  2. +171
    -10
      imperative/python/megengine/core/tensor/array_method.py
  3. +3
    -3
      imperative/python/megengine/data/dataloader.py
  4. +10
    -6
      imperative/python/megengine/data/dataset/meta_dataset.py
  5. +4
    -1
      imperative/python/megengine/data/dataset/vision/cifar.py
  6. +15
    -14
      imperative/python/megengine/data/dataset/vision/folder.py
  7. +1
    -1
      imperative/python/megengine/data/dataset/vision/mnist.py
  8. +39
    -35
      imperative/python/megengine/data/sampler.py
  9. +13
    -0
      imperative/python/megengine/data/transform/vision/transform.py
  10. +1
    -0
      imperative/python/megengine/distributed/__init__.py
  11. +11
    -0
      imperative/python/megengine/distributed/group.py
  12. +6
    -6
      imperative/python/megengine/functional/nn.py
  13. +1
    -1
      imperative/python/megengine/module/batch_matmul_activation.py
  14. +2
    -2
      imperative/python/megengine/module/concat.py
  15. +6
    -7
      imperative/python/megengine/module/conv.py
  16. +4
    -5
      imperative/python/megengine/module/conv_bn.py
  17. +2
    -2
      imperative/python/megengine/module/elemwise.py
  18. +4
    -0
      imperative/python/megengine/module/qat/batch_matmul_activation.py
  19. +1
    -1
      imperative/python/megengine/module/qat/concat.py
  20. +2
    -2
      imperative/python/megengine/module/qat/conv.py
  21. +2
    -2
      imperative/python/megengine/module/qat/conv_bn.py
  22. +2
    -2
      imperative/python/megengine/module/qat/elemwise.py
  23. +1
    -1
      imperative/python/megengine/module/qat/linear.py
  24. +2
    -2
      imperative/python/megengine/module/qat/module.py
  25. +2
    -2
      imperative/python/megengine/module/qat/quant_dequant.py
  26. +2
    -0
      imperative/python/megengine/module/quantized/batch_matmul_activation.py
  27. +1
    -1
      imperative/python/megengine/module/quantized/concat.py
  28. +4
    -4
      imperative/python/megengine/module/quantized/conv.py
  29. +3
    -3
      imperative/python/megengine/module/quantized/conv_bn.py
  30. +1
    -1
      imperative/python/megengine/module/quantized/elemwise.py
  31. +1
    -1
      imperative/python/megengine/module/quantized/linear.py
  32. +2
    -2
      imperative/python/megengine/module/quantized/module.py
  33. +2
    -2
      imperative/python/megengine/module/quantized/quant_dequant.py
  34. +40
    -0
      imperative/python/megengine/tensor.py

+ 1
- 0
imperative/python/megengine/core/autodiff/grad.py View File

@@ -91,6 +91,7 @@ class Function(ops.PyOpBase):
Examples:

.. code-block::

class Sigmoid(Function):
def forward(self, x):
y = 1 / (1 + F.exp(-x))


+ 171
- 10
imperative/python/megengine/core/tensor/array_method.py View File

@@ -362,6 +362,9 @@ class ArrayMethodMixin(abc.ABC):

@property
def ndim(self):
r"""
Returns the number of dimensions of self :class:`~.Tensor`.
"""
shape = self._tuple_shape
if shape is None:
raise ValueError("unkown ndim")
@@ -369,6 +372,10 @@ class ArrayMethodMixin(abc.ABC):

@property
def size(self):
r"""
Returns the size of the self :class:`~.Tensor`.
The returned value is a subclass of :class:`tuple`.
"""
shape = self.shape
if shape.__class__ is tuple:
return np.prod(self.shape).item()
@@ -376,9 +383,16 @@ class ArrayMethodMixin(abc.ABC):

@property
def T(self):
r"""
alias of :attr:`~.Tensor.transpose`.
"""
return self.transpose()

def item(self, *args):
r"""
Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`.
This only works for tensors with one element. For other cases, see :meth:`~.tolist`.
"""
if not args:
if isinstance(self.size, int):
assert self.size == 1
@@ -386,12 +400,26 @@ class ArrayMethodMixin(abc.ABC):
return self[args].item()

def tolist(self):
r"""
Returns the tensor as a (nested) list.
For scalars, a standard Python number is returned, just like with :meth:`~.item`.
Tensors are automatically moved to the CPU first if necessary.

This operation is not differentiable.
"""
return self.numpy().tolist()

def astype(self, dtype):
r"""
Returns a :class:`Tensor` with the same data and number of elements
with the specified :attr:`~.Tensor.dtype`.
"""
return utils.astype(self, dtype)

def reshape(self, *args):
r"""
See :func:`~.reshape`.
"""
return _reshape(self, _expand_args(args))

# FIXME: remove this method
@@ -399,6 +427,9 @@ class ArrayMethodMixin(abc.ABC):
return _broadcast(self, _expand_args(args))

def transpose(self, *args):
r"""
See :func:`~.transpose`.
"""
if self.ndim == 0:
assert (
len(args) == 0
@@ -411,19 +442,22 @@ class ArrayMethodMixin(abc.ABC):
return _transpose(self, _expand_args(args))

def flatten(self):
r"""
See :func:`~.flatten`.
"""
return self.reshape(-1)

def sum(self, axis=None, keepdims: bool = False):
r"""
Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.

If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.squeeze`).

Same for prod/mean/max/min.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

:param axis: the dimension or dimensions to reduce.
:param keepdim: whether the output tensor has ndim retained or not.
:param keepdims: whether the output tensor has ndim retained or not.
:return: output tensor.

Examples:
@@ -441,12 +475,139 @@ class ArrayMethodMixin(abc.ABC):
.. testoutput::

2
10.
10.0

"""
return _reduce("SUM")(self, axis, keepdims)

prod = _reduce("PRODUCT")
min = _reduce("MIN")
max = _reduce("MAX")
mean = _reduce("MEAN")
def prod(self, axis=None, keepdims: bool = False):
r"""
Returns the product of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

:param axis: the dimension or dimensions to reduce.
:param keepdims: whether the output tensor has ndim retained or not.
:return: output tensor.

Examples:

.. testcode::

from megengine import tensor
a = tensor([False, True, True, False])
b = tensor([1.0, 2.0, 3.0, 4.0])
print(a.prod().numpy())
print(b.prod().numpy())

Outputs:

.. testoutput::

0
24.0

"""
return _reduce("PRODUCT")(self, axis, keepdims)

def min(self, axis=None, keepdims: bool = False):
r"""
Returns the min value of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

:param axis: the dimension or dimensions to reduce.
:param keepdims: whether the output tensor has ndim retained or not.
:return: output tensor.

Examples:

.. testcode::

from megengine import tensor
a = tensor([False, True, True, False])
b = tensor([1.0, 2.0, 3.0, 4.0])
print(a.min().numpy())
print(b.min().numpy())

Outputs:

.. testoutput::

False
1.0

"""
return _reduce("MIN")(self, axis, keepdims)

def max(self, axis=None, keepdims: bool = False):
r"""
Returns the max value of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

:param axis: the dimension or dimensions to reduce.
:param keepdims: whether the output tensor has ndim retained or not.
:return: output tensor.

Examples:

.. testcode::

from megengine import tensor
a = tensor([False, True, True, False])
b = tensor([1.0, 2.0, 3.0, 4.0])
print(a.max().numpy())
print(b.max().numpy())

Outputs:

.. testoutput::

True
4.0

"""
return _reduce("MAX")(self, axis, keepdims)

def mean(self, axis=None, keepdims: bool = False):
r"""
Returns the mean value of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

:param axis: the dimension or dimensions to reduce.
:param keepdims: whether the output tensor has ndim retained or not.
:return: output tensor.

Examples:

.. testcode::

from megengine import tensor
a = tensor([False, True, True, False])
b = tensor([1.0, 2.0, 3.0, 4.0])
print(a.mean().numpy())
print(b.mean().numpy())

Outputs:

.. testoutput::

0.5
2.5

"""
return _reduce("MEAN")(self, axis, keepdims)

+ 3
- 3
imperative/python/megengine/data/dataloader.py View File

@@ -42,6 +42,9 @@ def raise_timeout_error():


class DataLoader:
r"""
Provides a convenient way to iterate on a given dataset.
"""
__initialized = False

def __init__(
@@ -56,8 +59,6 @@ class DataLoader:
divide: bool = False,
):
r"""
Provides a convenient way to iterate on a given dataset.

`DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
make it flexible to get minibatch continually from a dataset.

@@ -87,7 +88,6 @@ class DataLoader:
different sub-process will process different batch. Default: False

"""

if num_workers < 0:
raise ValueError("num_workers should not be negative")



+ 10
- 6
imperative/python/megengine/data/dataset/meta_dataset.py View File

@@ -12,7 +12,8 @@ from typing import Tuple

class Dataset(ABC):
r"""
An abstract class for all datasets.
An abstract base class for all datasets.

__getitem__ and __len__ method are aditionally needed.
"""

@@ -32,6 +33,7 @@ class Dataset(ABC):
class StreamDataset(Dataset):
r"""
An abstract class for stream data.

__iter__ method is aditionally needed.
"""

@@ -51,12 +53,14 @@ class StreamDataset(Dataset):


class ArrayDataset(Dataset):
r"""
ArrayDataset is a dataset for numpy array data.

One or more numpy arrays are needed to initiate the dataset.
And the dimensions represented sample number are expected to be the same.
"""

def __init__(self, *arrays):
r"""
ArrayDataset is a dataset for numpy array data, one or more numpy arrays
are needed to initiate the dataset. And the dimensions represented sample number
are expected to be the same.
"""
super().__init__()
if not all(len(arrays[0]) == len(array) for array in arrays):
raise ValueError("lengths of input arrays are inconsistent")


+ 4
- 1
imperative/python/megengine/data/dataset/vision/cifar.py View File

@@ -21,7 +21,7 @@ logger = get_logger(__name__)


class CIFAR10(VisionDataset):
r""" ``Dataset`` for CIFAR10 meta data.
r""" :class:`~.Dataset` for CIFAR10 meta data.
"""

url_path = "http://www.cs.utoronto.ca/~kriz/"
@@ -138,6 +138,9 @@ class CIFAR10(VisionDataset):


class CIFAR100(CIFAR10):
r""" :class:`~.Dataset` for CIFAR100 meta data.
"""

url_path = "http://www.cs.utoronto.ca/~kriz/"
raw_file_name = "cifar-100-python.tar.gz"
raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"


+ 15
- 14
imperative/python/megengine/data/dataset/vision/folder.py View File

@@ -26,24 +26,25 @@ from .utils import is_img


class ImageFolder(VisionDataset):
def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
r"""
ImageFolder is a class for loading image data and labels from a organized folder.
r"""
ImageFolder is a class for loading image data and labels from a organized folder.

The folder is expected to be organized as followed: root/cls/xxx.img_ext

The folder is expected to be organized as followed: root/cls/xxx.img_ext
Labels are indices of sorted classes in the root directory.

Labels are indices of sorted classes in the root directory.
:param root: root directory of an image folder.
:param loader: a function used to load image from path,
if ``None``, default function that loads
images with PIL will be called.
:param check_valid_func: a function used to check if files in folder are
expected image files, if ``None``, default function
that checks file extensions will be called.
:param class_name: if ``True``, return class name instead of class index.

:param root: root directory of an image folder.
:param loader: a function used to load image from path,
if ``None``, default function that loads
images with PIL will be called.
:param check_valid_func: a function used to check if files in folder are
expected image files, if ``None``, default function
that checks file extensions will be called.
:param class_name: if ``True``, return class name instead of class index.
"""

"""
def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
super().__init__(root, order=("image", "image_category"))

self.root = root


+ 1
- 1
imperative/python/megengine/data/dataset/vision/mnist.py View File

@@ -22,7 +22,7 @@ logger = get_logger(__name__)


class MNIST(VisionDataset):
r""" ``Dataset`` for MNIST meta data.
r""" :class:`~.Dataset` for MNIST meta data.
"""

url_path = "http://yann.lecun.com/exdb/mnist/"


+ 39
- 35
imperative/python/megengine/data/sampler.py View File

@@ -18,7 +18,7 @@ import megengine.distributed as dist

class Sampler(ABC):
r"""
An abstract class for all Sampler
An abstract base class for all Sampler
"""

@abstractmethod
@@ -27,6 +27,28 @@ class Sampler(ABC):


class MapSampler(Sampler):
r"""
Sampler for map dataset.

:type dataset: `dataset`
:param dataset: dataset to sample from.
:type batch_size: positive integer
:param batch_size: batch size for batch method.
:type drop_last: bool
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:type num_samples: positive integer
:param num_samples: number of samples assigned to one rank.
:type world_size: positive integer
:param world_size: number of ranks.
:type rank: non-negative integer within 0 and world_size
:param rank: rank id, non-negative interger within 0 and ``world_size``.
:type seed: non-negative integer
:param seed: seed for random operators.
"""

def __init__(
self,
dataset,
@@ -37,27 +59,6 @@ class MapSampler(Sampler):
rank=None,
seed=None,
):
r"""
An abstract class for all sampler.

:type dataset: `dataset`
:param dataset: dataset to sample from.
:type batch_size: positive integer
:param batch_size: batch size for batch method.
:type drop_last: bool
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:type num_samples: positive integer
:param num_samples: number of samples assigned to one rank.
:type world_size: positive integer
:param world_size: number of ranks.
:type rank: non-negative integer within 0 and world_size
:param rank: rank id, non-negative interger within 0 and ``world_size``.
:type seed: non-negative integer
:param seed: seed for random operators.
"""
if (
not isinstance(batch_size, int)
or isinstance(batch_size, bool)
@@ -156,7 +157,7 @@ class MapSampler(Sampler):


class StreamSampler(Sampler):
"""
r"""
Sampler for stream dataset.

.. warning::
@@ -181,6 +182,10 @@ class StreamSampler(Sampler):


class SequentialSampler(MapSampler):
r"""
Sample elements sequentially.
"""

def __init__(
self,
dataset,
@@ -190,9 +195,6 @@ class SequentialSampler(MapSampler):
world_size=None,
rank=None,
):
r"""
Sample elements sequentially.
"""
super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
if indices is not None and not isinstance(indices, collections.abc.Sequence):
raise ValueError(
@@ -212,6 +214,10 @@ class SequentialSampler(MapSampler):


class RandomSampler(MapSampler):
r"""
Sample elements randomly without replacement.
"""

def __init__(
self,
dataset,
@@ -222,9 +228,6 @@ class RandomSampler(MapSampler):
rank=None,
seed=None,
):
r"""
Sample elements randomly without replacement.
"""
super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
if indices is not None and not isinstance(indices, collections.abc.Sequence):
raise ValueError(
@@ -241,6 +244,13 @@ class RandomSampler(MapSampler):


class ReplacementSampler(MapSampler):
r"""
Sample elements randomly with replacement.

:type weights: List
:param weights: weights for sampling indices, it could be unnormalized weights.
"""

def __init__(
self,
dataset,
@@ -252,12 +262,6 @@ class ReplacementSampler(MapSampler):
rank=None,
seed=None,
):
r"""
Sample elements randomly with replacement.

:type weights: List
:param weights: weights for sampling indices, it could be unnormalized weights.
"""
super().__init__(
dataset, batch_size, drop_last, num_samples, world_size, rank, seed
)


+ 13
- 0
imperative/python/megengine/data/transform/vision/transform.py View File

@@ -410,6 +410,10 @@ class Resize(VisionTransform):


class ShortestEdgeResize(VisionTransform):
r"""
Resize the input data with specified shortset edge.
"""

def __init__(
self,
min_size,
@@ -1010,6 +1014,15 @@ class ColorJitter(VisionTransform):


class Lighting(VisionTransform):
r"""
Apply AlexNet-Style "lighting" augmentation to input data.

Input images are assumed to have 'RGB' channel order.

The degree of color jittering is randomly sampled via a normal distribution,
with standard deviation given by the scale parameter.
"""

def __init__(self, scale, *, order=None):
super().__init__(order)
if scale < 0:


+ 1
- 0
imperative/python/megengine/distributed/__init__.py View File

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .group import (
WORLD,
Group,
get_backend,
get_client,
get_mm_server_addr,


+ 11
- 0
imperative/python/megengine/distributed/group.py View File

@@ -29,6 +29,17 @@ _sd = None


class Group:
r"""
Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).

By default collectives operate on the default group (also called ``WORLD``)
and require all processes to enter the distributed function call.

:param proc_ranks: rank list of the group, the first one is root rank.

"""

def __init__(self, proc_ranks):
if len(proc_ranks) == 0: # empty group
self.proc_ranks = None


+ 6
- 6
imperative/python/megengine/functional/nn.py View File

@@ -108,7 +108,7 @@ def conv2d(
"""
2D convolution operation.

Refer to :class:`~.Conv2d` for more information.
Refer to :class:`~.module.Conv2d` for more information.

:param inp: feature map of the convolution operation.
:param weight: convolution kernel.
@@ -1046,9 +1046,9 @@ def warp_affine(

.. note::

Here all available options for params are listed,
however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported.
Here all available options for params are listed,
however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported.
"""
op = builtin.WarpAffine(
border_mode=border_mode, border_val=border_val, format=format, imode=imode
@@ -1088,9 +1088,9 @@ def warp_perspective(
Default: "LINEAR". Currently only support "LINEAR" mode.
:return: output tensor.

Note:
.. note::

The transformation matrix is the inverse of that used by `cv2.warpPerspective`.
The transformation matrix is the inverse of that used by `cv2.warpPerspective`.

Examples:



+ 1
- 1
imperative/python/megengine/module/batch_matmul_activation.py View File

@@ -15,7 +15,7 @@ from .module import Module

class BatchMatMulActivation(Module):
r"""
Batched MatMul with activation(only relu supported), no transpose anywhere.
Batched :func:`~.matmul` with activation(only :func:`~.relu` supported), no transpose anywhere.
"""

def __init__(


+ 2
- 2
imperative/python/megengine/module/concat.py View File

@@ -14,8 +14,8 @@ from .module import Module

class Concat(Module):
r"""
A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
A :class:`~.Module` to do functional :func:`~.concat`. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.Concat` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inps: Iterable[Tensor], axis: int = 0):


+ 6
- 7
imperative/python/megengine/module/conv.py View File

@@ -100,7 +100,7 @@ class Conv1d(_ConvNd):

For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
this layer generates an output of the size
:math:`(N, C_{\text{out}}, H_{\text{out}}})` through the
:math:`(N, C_{\text{out}}, H_{\text{out}})` through the
process described as below:

.. math::
@@ -130,7 +130,7 @@ class Conv1d(_ConvNd):
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 1D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be `(groups,
@@ -290,7 +290,7 @@ class Conv2d(_ConvNd):
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be `(groups,
@@ -422,7 +422,7 @@ class ConvTranspose2d(_ConvNd):
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be ``(groups,
@@ -592,9 +592,8 @@ class LocalConv2d(Conv2d):

class ConvRelu2d(Conv2d):
r"""
A fused :class:`~.Module` including Conv2d and relu. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
:func:`~.quantize.quantize_qat`.
A fused :class:`~.Module` including :class:`~.module.Conv2d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvRelu2d` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inp):


+ 4
- 5
imperative/python/megengine/module/conv_bn.py View File

@@ -51,8 +51,8 @@ class _ConvBnActivation2d(Module):

class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
A fused :class:`~.Module` including :class:`~.module.Conv2d` and :class:`~.module.BatchNorm2d`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBn2d` using
:func:`~.quantize.quantize_qat`.
"""

@@ -62,9 +62,8 @@ class ConvBn2d(_ConvBnActivation2d):

class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
:func:`~.quantize.quantize_qat`.
A fused :class:`~.Module` including :class:`~.module.Conv2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBnRelu2d` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inp):


+ 2
- 2
imperative/python/megengine/module/elemwise.py View File

@@ -12,8 +12,8 @@ from .module import Module

class Elemwise(Module):
r"""
A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.
A :class:`~.Module` to do :mod:`~.functional.elemwise` operator. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.Elemwise` using :func:`~.quantize.quantize_qat`.

:param method: the elemwise method, support the following string.
It will do the normal elemwise operator for float.


+ 4
- 0
imperative/python/megengine/module/qat/batch_matmul_activation.py View File

@@ -12,6 +12,10 @@ from .module import QATModule


class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
r"""
A :class:`~.QATModule` :class:`~.module.BatchMatMulActivation` with QAT support.
"""

def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, inp, w_qat)


+ 1
- 1
imperative/python/megengine/module/qat/concat.py View File

@@ -14,7 +14,7 @@ from .module import QATModule

class Concat(Float.Concat, QATModule):
r"""
A :class:`~.QATModule` to do functional concat with QAT support.
A :class:`~.QATModule` to do functional :func:`~.concat` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""



+ 2
- 2
imperative/python/megengine/module/qat/conv.py View File

@@ -13,7 +13,7 @@ from .module import QATModule

class Conv2d(Float.Conv2d, QATModule):
r"""
A :class:`~.QATModule` Conv2d with QAT support.
A :class:`~.QATModule` :class:`~.module.Conv2d` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""

@@ -54,7 +54,7 @@ class Conv2d(Float.Conv2d, QATModule):

class ConvRelu2d(Conv2d):
r"""
A :class:`~.QATModule` include Conv2d and Relu with QAT support.
A :class:`~.QATModule` include :class:`~.module.Conv2d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""



+ 2
- 2
imperative/python/megengine/module/qat/conv_bn.py View File

@@ -164,7 +164,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):

class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support.
A fused :class:`~.QATModule` including :class:`~.module.Conv2d` and :class:`~.module.BatchNorm2d` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""

@@ -174,7 +174,7 @@ class ConvBn2d(_ConvBnActivation2d):

class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
A fused :class:`~.QATModule` including :class:`~.module.Conv2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""



+ 2
- 2
imperative/python/megengine/module/qat/elemwise.py View File

@@ -11,10 +11,10 @@ from .module import QATModule

class Elemwise(Float.Elemwise, QATModule):
r"""
A :class:`~.QATModule` to do elemwise operator with QAT support.
A :class:`~.QATModule` to do :mod:`~.functional.elemwise` operator with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
:param method: the elemwise method, see :class:`~.module.Elemwise` for detail.
"""

with_weight = False


+ 1
- 1
imperative/python/megengine/module/qat/linear.py View File

@@ -12,7 +12,7 @@ from .module import QATModule

class Linear(Float.Linear, QATModule):
r"""
A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
A :class:`~.QATModule` version of :class:`~.module.Linear`.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

:param in_features: size of each input sample.


+ 2
- 2
imperative/python/megengine/module/qat/module.py View File

@@ -14,9 +14,9 @@ from ..module import Module

class QATModule(Module):
r"""
Base class of quantized-float related Module, basically for QAT and Calibration.
Base class of quantized-float related :class:`~.Module`, basically for QAT and Calibration.

Use :meth:`~.QATModule.from_float_module` to generate a instance from float :class:`~.Module`.
Use :meth:`from_float_module` to generate a instance from float :class:`~.Module`.
Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.

Can also be converted to :class:`~.QuantizedModule` for deployment using


+ 2
- 2
imperative/python/megengine/module/qat/quant_dequant.py View File

@@ -11,7 +11,7 @@ from .module import QATModule

class QuantStub(Float.QuantStub, QATModule):
r"""
A helper QATModule simply return input, but will quantize
A helper :class:`~.QATModule` simply return input, but will quantize
input after converted to :class:`~.QuantizedModule`.
"""

@@ -31,7 +31,7 @@ class QuantStub(Float.QuantStub, QATModule):

class DequantStub(Float.DequantStub, QATModule):
r"""
A helper QATModule simply return input, but will de-quantize
A helper :class:`~.QATModule` simply return input, but will de-quantize
input after converted to :class:`~.QuantizedModule`.
"""



+ 2
- 0
imperative/python/megengine/module/quantized/batch_matmul_activation.py View File

@@ -19,6 +19,8 @@ from .module import QuantizedModule


class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
r"""Quantized version of :class:`~.qat.BatchMatMulActivation`."""

def __init__(
self,
batch: int,


+ 1
- 1
imperative/python/megengine/module/quantized/concat.py View File

@@ -15,7 +15,7 @@ from .module import QuantizedModule

class Concat(QuantizedModule):
r"""
A :class:`~.QuantizedModule` to do quantized concat, used for inference only.
A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only.
"""

def __init__(self, dtype=None):


+ 4
- 4
imperative/python/megengine/module/quantized/conv.py View File

@@ -18,11 +18,11 @@ from .module import QuantizedModule


class Conv2d(Float.Conv2d, QuantizedModule):
r"""Quantized version of :class:`~.qat.conv.Conv2d`."""
r"""
r"""Quantized version of :class:`~.qat.Conv2d`.
Applies a 2D convolution over a quantized input tensor, used for inference only.

The parameter is same with :class: `~.Conv2d`.
The parameter is same with :class:`~.module.Conv2d`.
"""

def __init__(
@@ -102,7 +102,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):


class ConvRelu2d(Conv2d):
r"""Quantized version of :class:`~.qat.conv.ConvRelu2d`."""
r"""Quantized version of :class:`~.qat.ConvRelu2d`."""

def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")

+ 3
- 3
imperative/python/megengine/module/quantized/conv_bn.py View File

@@ -14,7 +14,7 @@ class _ConvBnActivation2d(Conv2d):
r"""
Applies a 2D convolution over a quantized input tensor, used for inference only.

The parameter is same with :class: `~.Conv2d`.
The parameter is same with :class: `~.module.Conv2d`.
"""

@classmethod
@@ -44,14 +44,14 @@ class _ConvBnActivation2d(Conv2d):


class ConvBn2d(_ConvBnActivation2d):
r"""Quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""
r"""Quantized version of :class:`~.qat.ConvBn2d`."""

def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
r"""Quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""
r"""Quantized version of :class:`~.qat.ConvBnRelu2d`."""

def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")

+ 1
- 1
imperative/python/megengine/module/quantized/elemwise.py View File

@@ -12,7 +12,7 @@ from .module import QuantizedModule


class Elemwise(QuantizedModule):
r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""
r"""Quantized version of :class:`~.qat.Elemwise`."""

def __init__(self, method, dtype=None):
super().__init__()


+ 1
- 1
imperative/python/megengine/module/quantized/linear.py View File

@@ -15,7 +15,7 @@ from .module import QuantizedModule


class Linear(QuantizedModule):
r"""Quantized version of :class:`~.qat.linear.Linear`."""
r"""Quantized version of :class:`~.qat.Linear`."""

def __init__(self, dtype: np.dtype = None):
super().__init__()


+ 2
- 2
imperative/python/megengine/module/quantized/module.py View File

@@ -13,8 +13,8 @@ from ..qat import QATModule

class QuantizedModule(Module):
r"""
Base class of quantized Module, which should be converted from QATModule
and not support traning.
Base class of quantized :class:`~.Module`,
which should be converted from :class:`~.QATModule` and not support traning.
"""

def __call__(self, *inputs, **kwargs):


+ 2
- 2
imperative/python/megengine/module/quantized/quant_dequant.py View File

@@ -11,7 +11,7 @@ from .module import QuantizedModule

class QuantStub(QuantizedModule):
r"""
Quantized version of :class:`~.qat.quant_dequant.QuantStub`,
Quantized version of :class:`~.qat.QuantStub`,
will convert input to quantized dtype.
"""

@@ -33,7 +33,7 @@ class QuantStub(QuantizedModule):

class DequantStub(QuantizedModule):
r"""
Quantized version of :class:`~.qat.quant_dequant.DequantStub`,
Quantized version of :class:`~.qat.DequantStub`,
will restore quantized input to float32 dtype.
"""



+ 40
- 0
imperative/python/megengine/tensor.py View File

@@ -24,6 +24,10 @@ from .utils.naming import auto_naming


class Tensor(_Tensor, ArrayMethodMixin):
r"""
A tensor object represents a multidimensional, homogeneous array of fixed-size items.
"""

grad = None
dmap_callback = None
_q_dict = None
@@ -59,6 +63,20 @@ class Tensor(_Tensor, ArrayMethodMixin):

@property
def shape(self) -> Union[tuple, "Tensor"]:
r"""
Returns a :class:`tuple` or a :class:`~.Tensor` represents tensor dimensions.

.. note::
The shape of a tensor was usually represented by a :class:`tuple`.
But if a tensor was treated as symbolic placeholder with tracing,
it's shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.

The shape property is usually used to get the current shape of a tensor,
but may also be used to reshape the tensor in-place by assigning a tuple of tensor dimensions to it.
As with :func:`~.reshape`, one of the new shape dimensions can be -1,
in which case its value is inferred from the size of the tensor and the remaining dimensions.
"""
shape = super().shape
if shape == () or not use_symbolic_shape():
return shape
@@ -69,7 +87,17 @@ class Tensor(_Tensor, ArrayMethodMixin):
return super().shape

@property
def device(self) -> CompNode:
r"""
Returns a string represents the device a :class:`~.Tensor` storaged on.
"""
return super().device

@property
def dtype(self) -> np.dtype:
r"""
Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.
"""
return super().dtype

@property
@@ -79,8 +107,17 @@ class Tensor(_Tensor, ArrayMethodMixin):
return self._q_dict

def numpy(self) -> np.ndarray:
r"""
Returns self :class:`~.Tensor` as a :class:`numpy.ndarray`.
"""
return super().numpy()

def detach(self):
r"""
Returns a new :class:`~.Tensor`, detached from the current graph.
"""
return super().detach()

def _reset(self, other):
super()._reset(other)

@@ -113,6 +150,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
self *= 0

def to(self, device):
r"""
Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`
"""
if isinstance(device, str) and not _valid_device(device):
raise ValueError(
"invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(


Loading…
Cancel
Save