GitOrigin-RevId: 939a4d26dd
release-1.7
@@ -69,6 +69,7 @@ __all__ = [ | |||
"leaky_relu", | |||
"linear", | |||
"local_conv2d", | |||
"local_response_norm", | |||
"logsigmoid", | |||
"logsumexp", | |||
"logsoftmax", | |||
@@ -1746,6 +1747,53 @@ def pad( | |||
return output | |||
def local_response_norm( | |||
inp: Tensor, | |||
kernel_size: int = 5, | |||
k: float = 2.0, | |||
alpha: float = 1e-4, | |||
beta: float = 0.75, | |||
) -> Tensor: | |||
r""" | |||
Apply local response normalization to the input tensor. | |||
Args: | |||
kernel_size: the size of the kernel to apply LRN on. | |||
k: hyperparameter k. The default vaule is 2.0. | |||
alpha: hyperparameter alpha. The default value is 1e-4. | |||
beta: hyperparameter beta. The default value is 0.75. | |||
Example: | |||
.. testcode:: | |||
from megengine import tensor | |||
import megengine.functional as f | |||
import numpy as np | |||
inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5)) | |||
GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066], | |||
[ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ], | |||
[ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675], | |||
[14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ], | |||
[19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]]) | |||
out = f.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75) | |||
np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6) | |||
print('pass') | |||
Outputs: | |||
.. testoutput:: | |||
pass | |||
""" | |||
op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,) | |||
(output,) = apply(op, inp) | |||
return output | |||
@lru_cache(maxsize=None) | |||
def _get_layerPixelShuffle(device, dtype, dim_order): | |||
@subgraph("LayerPixelShuffle", dtype, device, 3) | |||
@@ -29,6 +29,7 @@ from .elemwise import Elemwise | |||
from .embedding import Embedding | |||
from .identity import Identity | |||
from .linear import Linear | |||
from .lrn import LocalResponseNorm | |||
from .module import Module | |||
from .normalization import GroupNorm, InstanceNorm, LayerNorm | |||
from .padding import Pad | |||
@@ -0,0 +1,69 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Tuple, Union | |||
from ..functional import local_response_norm | |||
from .module import Module | |||
class LocalResponseNorm(Module): | |||
r""" | |||
Apply local response normalization to the input tensor. | |||
Args: | |||
kernel_size: the size of the kernel to apply LRN on. | |||
k: hyperparameter k. The default vaule is 2.0. | |||
alpha: hyperparameter alpha. The default value is 1e-4. | |||
beta: hyperparameter beta. The default value is 0.75. | |||
Example: | |||
.. testcode:: | |||
from megengine import tensor | |||
import megengine.module as M | |||
import numpy as np | |||
inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5)) | |||
GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066], | |||
[ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ], | |||
[ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675], | |||
[14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ], | |||
[19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]]) | |||
op = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75) | |||
out = op(inp) | |||
np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6) | |||
print('pass') | |||
Outputs: | |||
.. testoutput:: | |||
pass | |||
""" | |||
def __init__( | |||
self, | |||
kernel_size: int = 5, | |||
k: float = 2.0, | |||
alpha: float = 1e-4, | |||
beta: float = 0.75, | |||
**kwargs | |||
): | |||
super(LocalResponseNorm, self).__init__(**kwargs) | |||
self.kernel_size = kernel_size | |||
self.k = k | |||
self.alpha = alpha | |||
self.beta = beta | |||
def forward(self, inp): | |||
return local_response_norm(inp, self.kernel_size, self.k, self.alpha, self.beta) |
@@ -21,6 +21,7 @@ | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/lrn.h" | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
@@ -654,4 +655,13 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
} | |||
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace padding | |||
namespace lrn { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LRN&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::LRN::make(inputs[0], op.param()); | |||
} | |||
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace LRN | |||
} // namespace mgb::imperative |
@@ -422,4 +422,6 @@ def Split: MgbHashableOp<"Split", [EmptyParam]> { | |||
def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||
def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||
#endif // MGB_OPS |