diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index d1cc7286..b7ff90cf 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1708,7 +1708,7 @@ def sliding_window_transpose( def pad( src: Tensor, - pad_witdth: Tuple[Tuple[int, int], ...], + pad_width: Tuple[Tuple[int, int], ...], mode: str = "constant", constant_value: float = 0.0, ) -> Tensor: @@ -1723,9 +1723,9 @@ def pad( if mode.lower() == "edge": mode = "replicate" - for i in range(0, len(pad_witdth)): - p_offsets[i * 2] = pad_witdth[i][0] - p_offsets[i * 2 + 1] = pad_witdth[i][1] + for i in range(0, len(pad_width)): + p_offsets[i * 2] = pad_width[i][0] + p_offsets[i * 2 + 1] = pad_width[i][1] op = builtin.Padding( front_offset_dim0=p_offsets[0], diff --git a/imperative/python/megengine/module/padding.py b/imperative/python/megengine/module/padding.py index 433ecd32..078daabc 100644 --- a/imperative/python/megengine/module/padding.py +++ b/imperative/python/megengine/module/padding.py @@ -12,16 +12,16 @@ class Pad(Module): def __init__( self, - pad_witdth: Tuple[Tuple[int, int], ...], + pad_width: Tuple[Tuple[int, int], ...], mode: str = "constant", constant_val: float = 0.0, ): super().__init__() - self.pad_width = pad_witdth + self.pad_width = pad_width self.mode = mode self.pad_val = constant_val def forward(self, src): return nn.pad( - src, pad_witdth=self.pad_width, mode=self.mode, constant_value=self.pad_val + src, pad_width=self.pad_width, mode=self.mode, constant_value=self.pad_val ) diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 9350b1bd..26d66cf2 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -162,3 +162,11 @@ def tensor_gen_func_loader(expr): else: device = None expr.set_args_kwargs(shape, dtype=dtype, device=device) + + +@register_functional_loader(("megengine.functional.nn", "pad")) +def pad_func_loader(expr): + if "pad_witdth" in expr.kwargs: + kwargs = expr.kwargs + kwargs["pad_width"] = kwargs.pop("pad_witdth") + expr.set_args_kwargs(*expr.args, **kwargs)