@@ -1708,7 +1708,7 @@ def sliding_window_transpose( | |||||
def pad( | def pad( | ||||
src: Tensor, | src: Tensor, | ||||
pad_witdth: Tuple[Tuple[int, int], ...], | |||||
pad_width: Tuple[Tuple[int, int], ...], | |||||
mode: str = "constant", | mode: str = "constant", | ||||
constant_value: float = 0.0, | constant_value: float = 0.0, | ||||
) -> Tensor: | ) -> Tensor: | ||||
@@ -1723,9 +1723,9 @@ def pad( | |||||
if mode.lower() == "edge": | if mode.lower() == "edge": | ||||
mode = "replicate" | mode = "replicate" | ||||
for i in range(0, len(pad_witdth)): | |||||
p_offsets[i * 2] = pad_witdth[i][0] | |||||
p_offsets[i * 2 + 1] = pad_witdth[i][1] | |||||
for i in range(0, len(pad_width)): | |||||
p_offsets[i * 2] = pad_width[i][0] | |||||
p_offsets[i * 2 + 1] = pad_width[i][1] | |||||
op = builtin.Padding( | op = builtin.Padding( | ||||
front_offset_dim0=p_offsets[0], | front_offset_dim0=p_offsets[0], | ||||
@@ -12,16 +12,16 @@ class Pad(Module): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
pad_witdth: Tuple[Tuple[int, int], ...], | |||||
pad_width: Tuple[Tuple[int, int], ...], | |||||
mode: str = "constant", | mode: str = "constant", | ||||
constant_val: float = 0.0, | constant_val: float = 0.0, | ||||
): | ): | ||||
super().__init__() | super().__init__() | ||||
self.pad_width = pad_witdth | |||||
self.pad_width = pad_width | |||||
self.mode = mode | self.mode = mode | ||||
self.pad_val = constant_val | self.pad_val = constant_val | ||||
def forward(self, src): | def forward(self, src): | ||||
return nn.pad( | return nn.pad( | ||||
src, pad_witdth=self.pad_width, mode=self.mode, constant_value=self.pad_val | |||||
src, pad_width=self.pad_width, mode=self.mode, constant_value=self.pad_val | |||||
) | ) |
@@ -162,3 +162,11 @@ def tensor_gen_func_loader(expr): | |||||
else: | else: | ||||
device = None | device = None | ||||
expr.set_args_kwargs(shape, dtype=dtype, device=device) | expr.set_args_kwargs(shape, dtype=dtype, device=device) | ||||
@register_functional_loader(("megengine.functional.nn", "pad")) | |||||
def pad_func_loader(expr): | |||||
if "pad_witdth" in expr.kwargs: | |||||
kwargs = expr.kwargs | |||||
kwargs["pad_width"] = kwargs.pop("pad_witdth") | |||||
expr.set_args_kwargs(*expr.args, **kwargs) |