diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index c127cc57..9ababceb 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1587,8 +1587,10 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: [0 0 1 0] [0 0 0 1]] """ - zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) - ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) + zeros_tensor = zeros( + list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device + ) + ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) op = builtin.IndexingSetOneHot(axis=inp.ndim) (result,) = apply(op, zeros_tensor, inp, ones_tensor)