GitOrigin-RevId: 939a4d26dd
release-1.7
@@ -69,6 +69,7 @@ __all__ = [ | |||||
"leaky_relu", | "leaky_relu", | ||||
"linear", | "linear", | ||||
"local_conv2d", | "local_conv2d", | ||||
"local_response_norm", | |||||
"logsigmoid", | "logsigmoid", | ||||
"logsumexp", | "logsumexp", | ||||
"logsoftmax", | "logsoftmax", | ||||
@@ -1746,6 +1747,53 @@ def pad( | |||||
return output | return output | ||||
def local_response_norm( | |||||
inp: Tensor, | |||||
kernel_size: int = 5, | |||||
k: float = 2.0, | |||||
alpha: float = 1e-4, | |||||
beta: float = 0.75, | |||||
) -> Tensor: | |||||
r""" | |||||
Apply local response normalization to the input tensor. | |||||
Args: | |||||
kernel_size: the size of the kernel to apply LRN on. | |||||
k: hyperparameter k. The default vaule is 2.0. | |||||
alpha: hyperparameter alpha. The default value is 1e-4. | |||||
beta: hyperparameter beta. The default value is 0.75. | |||||
Example: | |||||
.. testcode:: | |||||
from megengine import tensor | |||||
import megengine.functional as f | |||||
import numpy as np | |||||
inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5)) | |||||
GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066], | |||||
[ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ], | |||||
[ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675], | |||||
[14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ], | |||||
[19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]]) | |||||
out = f.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75) | |||||
np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6) | |||||
print('pass') | |||||
Outputs: | |||||
.. testoutput:: | |||||
pass | |||||
""" | |||||
op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,) | |||||
(output,) = apply(op, inp) | |||||
return output | |||||
@lru_cache(maxsize=None) | @lru_cache(maxsize=None) | ||||
def _get_layerPixelShuffle(device, dtype, dim_order): | def _get_layerPixelShuffle(device, dtype, dim_order): | ||||
@subgraph("LayerPixelShuffle", dtype, device, 3) | @subgraph("LayerPixelShuffle", dtype, device, 3) | ||||
@@ -29,6 +29,7 @@ from .elemwise import Elemwise | |||||
from .embedding import Embedding | from .embedding import Embedding | ||||
from .identity import Identity | from .identity import Identity | ||||
from .linear import Linear | from .linear import Linear | ||||
from .lrn import LocalResponseNorm | |||||
from .module import Module | from .module import Module | ||||
from .normalization import GroupNorm, InstanceNorm, LayerNorm | from .normalization import GroupNorm, InstanceNorm, LayerNorm | ||||
from .padding import Pad | from .padding import Pad | ||||
@@ -0,0 +1,69 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from typing import Tuple, Union | |||||
from ..functional import local_response_norm | |||||
from .module import Module | |||||
class LocalResponseNorm(Module): | |||||
r""" | |||||
Apply local response normalization to the input tensor. | |||||
Args: | |||||
kernel_size: the size of the kernel to apply LRN on. | |||||
k: hyperparameter k. The default vaule is 2.0. | |||||
alpha: hyperparameter alpha. The default value is 1e-4. | |||||
beta: hyperparameter beta. The default value is 0.75. | |||||
Example: | |||||
.. testcode:: | |||||
from megengine import tensor | |||||
import megengine.module as M | |||||
import numpy as np | |||||
inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5)) | |||||
GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066], | |||||
[ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ], | |||||
[ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675], | |||||
[14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ], | |||||
[19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]]) | |||||
op = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75) | |||||
out = op(inp) | |||||
np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6) | |||||
print('pass') | |||||
Outputs: | |||||
.. testoutput:: | |||||
pass | |||||
""" | |||||
def __init__( | |||||
self, | |||||
kernel_size: int = 5, | |||||
k: float = 2.0, | |||||
alpha: float = 1e-4, | |||||
beta: float = 0.75, | |||||
**kwargs | |||||
): | |||||
super(LocalResponseNorm, self).__init__(**kwargs) | |||||
self.kernel_size = kernel_size | |||||
self.k = k | |||||
self.alpha = alpha | |||||
self.beta = beta | |||||
def forward(self, inp): | |||||
return local_response_norm(inp, self.kernel_size, self.k, self.alpha, self.beta) |
@@ -21,6 +21,7 @@ | |||||
#include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
#include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/dnn/lrn.h" | |||||
#include "megbrain/opr/dnn/lsq.h" | #include "megbrain/opr/dnn/lsq.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
@@ -654,4 +655,13 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
} | } | ||||
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback(); | OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback(); | ||||
} // namespace padding | } // namespace padding | ||||
namespace lrn { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const LRN&>(def); | |||||
mgb_assert(inputs.size() == 1); | |||||
return opr::LRN::make(inputs[0], op.param()); | |||||
} | |||||
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace LRN | |||||
} // namespace mgb::imperative | } // namespace mgb::imperative |
@@ -422,4 +422,6 @@ def Split: MgbHashableOp<"Split", [EmptyParam]> { | |||||
def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | ||||
def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||||
#endif // MGB_OPS | #endif // MGB_OPS |