diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index ebcdbbcb..5324067c 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -69,6 +69,7 @@ __all__ = [
     "leaky_relu",
     "linear",
     "local_conv2d",
+    "local_response_norm",
     "logsigmoid",
     "logsumexp",
     "logsoftmax",
@@ -1746,6 +1747,53 @@ def pad(
     return output
 
 
+def local_response_norm(
+    inp: Tensor,
+    kernel_size: int = 5,
+    k: float = 2.0,
+    alpha: float = 1e-4,
+    beta: float = 0.75,
+) -> Tensor:
+    r"""
+    Apply local response normalization to the input tensor.
+
+    Args:
+        kernel_size: the size of the kernel to apply LRN on.
+        k: hyperparameter k. The default value is 2.0.
+        alpha: hyperparameter alpha. The default value is 1e-4.
+        beta: hyperparameter beta. The default value is 0.75.
+
+    Example:
+
+    .. testcode::
+
+        from megengine import tensor
+        import megengine.functional as f
+        import numpy as np
+
+        inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
+        GT = np.array([[[[ 0.,         0.999925,   1.9994003,  2.9979765,  3.9952066],
+                         [ 4.9906454,  5.983851,   6.974385,   7.961814,   8.945709 ],
+                         [ 9.925651,  10.90122,   11.872011,  12.837625,  13.7976675],
+                         [14.751757,  15.699524,  16.640602,  17.574642,  18.501305 ],
+                         [19.420258,  20.331186,  21.233786,  22.127764,  23.012836 ]]]])
+
+        out = f.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
+        np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
+        print('pass')
+
+    Outputs:
+
+    .. testoutput::
+
+        pass
+
+    """
+    op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,)
+    (output,) = apply(op, inp)
+    return output
+
+
 @lru_cache(maxsize=None)
 def _get_layerPixelShuffle(device, dtype, dim_order):
     @subgraph("LayerPixelShuffle", dtype, device, 3)
diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py
index 7a5ad13e..3e127fa7 100644
--- a/imperative/python/megengine/module/__init__.py
+++ b/imperative/python/megengine/module/__init__.py
@@ -29,6 +29,7 @@ from .elemwise import Elemwise
 from .embedding import Embedding
 from .identity import Identity
 from .linear import Linear
+from .lrn import LocalResponseNorm
 from .module import Module
 from .normalization import GroupNorm, InstanceNorm, LayerNorm
 from .padding import Pad
diff --git a/imperative/python/megengine/module/lrn.py b/imperative/python/megengine/module/lrn.py
new file mode 100644
index 00000000..05f6ba18
--- /dev/null
+++ b/imperative/python/megengine/module/lrn.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Tuple, Union
+
+from ..functional import local_response_norm
+from .module import Module
+
+
+class LocalResponseNorm(Module):
+    r"""
+    Apply local response normalization to the input tensor.
+
+    Args:
+        kernel_size: the size of the kernel to apply LRN on.
+        k: hyperparameter k. The default value is 2.0.
+        alpha: hyperparameter alpha. The default value is 1e-4.
+        beta: hyperparameter beta. The default value is 0.75.
+
+    Example:
+
+    .. testcode::
+
+        from megengine import tensor
+        import megengine.module as M
+        import numpy as np
+
+        inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
+        GT = np.array([[[[ 0.,         0.999925,   1.9994003,  2.9979765,  3.9952066],
+                         [ 4.9906454,  5.983851,   6.974385,   7.961814,   8.945709 ],
+                         [ 9.925651,  10.90122,   11.872011,  12.837625,  13.7976675],
+                         [14.751757,  15.699524,  16.640602,  17.574642,  18.501305 ],
+                         [19.420258,  20.331186,  21.233786,  22.127764,  23.012836 ]]]])
+
+        op = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
+        out = op(inp)
+        np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
+        print('pass')
+
+
+    Outputs:
+
+    .. testoutput::
+
+        pass
+
+    """
+
+    def __init__(
+        self,
+        kernel_size: int = 5,
+        k: float = 2.0,
+        alpha: float = 1e-4,
+        beta: float = 0.75,
+        **kwargs
+    ):
+        super(LocalResponseNorm, self).__init__(**kwargs)
+        self.kernel_size = kernel_size
+        self.k = k
+        self.alpha = alpha
+        self.beta = beta
+
+    def forward(self, inp):
+        return local_response_norm(inp, self.kernel_size, self.k, self.alpha, self.beta)
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index 3d596d3e..2e10c9a9 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -21,6 +21,7 @@
 #include "megbrain/opr/dnn/fake_quant.h"
 #include "megbrain/opr/dnn/images2neibs.h"
 #include "megbrain/opr/dnn/local.h"
+#include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/lsq.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/dnn/roi_align.h"
@@ -654,4 +655,13 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 }
 OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace padding
+
+namespace lrn {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const LRN&>(def);
+    mgb_assert(inputs.size() == 1);
+    return opr::LRN::make(inputs[0], op.param());
+}
+OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace lrn
 } // namespace mgb::imperative
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index f4d873e4..d500f532 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -422,4 +422,6 @@ def Split: MgbHashableOp<"Split", [EmptyParam]> {
 
 def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
 
+def LRN: MgbHashableOp<"LRN", [LRNParam]>;
+
 #endif // MGB_OPS
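
For reference, the GT values in the two docstring examples above are consistent with the cross-channel LRN formula out[c] = x[c] / (k + alpha * sum of x[c']^2 over a kernel_size-channel window centered on c) ** beta. The snippet below is a minimal NumPy sketch of that formula under this assumption; lrn_reference is a hypothetical helper used only for illustration here, not part of this patch or of the MegEngine API.

# NumPy sketch of cross-channel LRN (illustrative only, not part of the patch).
import numpy as np

def lrn_reference(x, kernel_size=5, k=2.0, alpha=1e-4, beta=0.75):
    # x has layout (N, C, H, W); normalization runs across the channel axis.
    n, c, h, w = x.shape
    half = kernel_size // 2
    sq = x ** 2
    out = np.empty_like(x)
    for ch in range(c):
        # Window of up to `kernel_size` channels centered on `ch`, clipped at the edges.
        lo, hi = max(0, ch - half), min(c, ch + half + 1)
        denom = k + alpha * sq[:, lo:hi].sum(axis=1)
        out[:, ch] = x[:, ch] / denom ** beta
    return out

# With a single channel the window sum reduces to x^2, which reproduces the GT
# values used in the docstring examples (kernel_size=3, k=1.0, alpha=1e-4, beta=0.75).
x = np.arange(25, dtype=np.float32).reshape(1, 1, 5, 5)
print(lrn_reference(x, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75))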