Browse Source

fix(mge/functional): fix one_hot irregular coding style

From GITHUB 399

    ORIGINAL_AUTHOR=Asthestarsfalll <1186454801@qq.com>
    COPYBARA_INTEGRATE_REVIEW=https://github.com/MegEngine/MegEngine/pull/399 from Asthestarsfalll:master 541bf3af29
    GITHUB_PUBLIC_PR_NUMBER=399
    GITHUB_PR_URL=https://github.com/MegEngine/MegEngine/pull/399

GitOrigin-RevId: 5df007207a
tags/v1.7.2.m1
Megvii Engine Team XindaH 3 years ago
parent
commit
7cb3ad8a3e
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      imperative/python/megengine/functional/nn.py

+ 4
- 2
imperative/python/megengine/functional/nn.py View File

@@ -1587,8 +1587,10 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
[0 0 1 0]
[0 0 0 1]]
"""
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
zeros_tensor = zeros(
list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device
)
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)

op = builtin.IndexingSetOneHot(axis=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor)


Loading…
Cancel
Save