From ffcb4dacd9cd6199f7e1bc7ba517be973c838b3a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 13 Sep 2021 10:59:35 +0800 Subject: [PATCH] style(mge): change the pad opr doc and format GitOrigin-RevId: 7b20e965198c42e8bda92b42865ed649c25c06c2 --- imperative/python/megengine/functional/nn.py | 16 ++++------------ imperative/python/megengine/module/padding.py | 7 ++++++- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index e862c760..fdc9eb7e 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1595,24 +1595,16 @@ def sliding_window_transpose( def pad( src: Tensor, pad_witdth: Tuple[Tuple[int, int], ...], - mode: str = "CONSTANT", + mode: str = "constant", constant_value: float = 0.0, ) -> Tensor: """ - pad + Pad is python warpper for padding opr in megbrain, can padding in random one of the max 7 dimensions. + Supported constant, edge(replicate) and reflect mode, constatnt is the default mode. """ p_offsets = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - assert mode in [ - "constant", - "CONSTANT", - "edge", - "EDGE", - "replicate", - "REPLICATE", - "reflect", - "REFLECT", - ] + assert mode.lower() in ["constant", "edge", "replicate", "reflect"] if mode.lower() == "edge": mode = "replicate" diff --git a/imperative/python/megengine/module/padding.py b/imperative/python/megengine/module/padding.py index 0dcc3714..433ecd32 100644 --- a/imperative/python/megengine/module/padding.py +++ b/imperative/python/megengine/module/padding.py @@ -5,10 +5,15 @@ from .module import Module class Pad(Module): + """ + Pad is python warpper for padding opr in megbrain, can padding in random one of the max 7 dimensions. + Supported constant, edge(replicate) and reflect mode, constatnt is the default mode. + """ + def __init__( self, pad_witdth: Tuple[Tuple[int, int], ...], - mode: str = "CONSTANT", + mode: str = "constant", constant_val: float = 0.0, ): super().__init__()