You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concat.py 592 B

123456789101112131415161718
  1. from typing import Iterable
  2. from ...tensor import Tensor
  3. from .. import concat as Float
  4. from .module import QATModule
  5. class Concat(Float.Concat, QATModule):
  6. r"""A :class:`~.QATModule` to do functional :func:`~.concat` with QAT support.
  7. Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
  8. """
  9. def forward(self, inps: Iterable[Tensor], axis: int = 0):
  10. return self.apply_quant_activation(super().forward(inps, axis))
  11. @classmethod
  12. def from_float_module(cls, float_module):
  13. return cls(name=float_module.name)