You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

lrn.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Tuple, Union
  10. from ..functional import local_response_norm
  11. from .module import Module
  12. class LocalResponseNorm(Module):
  13. r"""
  14. Apply local response normalization to the input tensor.
  15. Args:
  16. kernel_size: the size of the kernel to apply LRN on.
  17. k: hyperparameter k. The default vaule is 2.0.
  18. alpha: hyperparameter alpha. The default value is 1e-4.
  19. beta: hyperparameter beta. The default value is 0.75.
  20. Example:
  21. .. testcode::
  22. from megengine import tensor
  23. import megengine.module as M
  24. import numpy as np
  25. inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
  26. GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066],
  27. [ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ],
  28. [ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675],
  29. [14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ],
  30. [19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]])
  31. op = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
  32. out = op(inp)
  33. np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
  34. print('pass')
  35. Outputs:
  36. .. testoutput::
  37. pass
  38. """
  39. def __init__(
  40. self,
  41. kernel_size: int = 5,
  42. k: float = 2.0,
  43. alpha: float = 1e-4,
  44. beta: float = 0.75,
  45. **kwargs
  46. ):
  47. super(LocalResponseNorm, self).__init__(**kwargs)
  48. self.kernel_size = kernel_size
  49. self.k = k
  50. self.alpha = alpha
  51. self.beta = beta
  52. def forward(self, inp):
  53. return local_response_norm(inp, self.kernel_size, self.k, self.alpha, self.beta)