|
|
@@ -1587,8 +1587,8 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: |
|
|
|
[0 0 1 0] |
|
|
|
[0 0 0 1]] |
|
|
|
""" |
|
|
|
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) |
|
|
|
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) |
|
|
|
zeros_tensor = zeros(list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device) |
|
|
|
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) |
|
|
|
|
|
|
|
op = builtin.IndexingSetOneHot(axis=inp.ndim) |
|
|
|
(result,) = apply(op, zeros_tensor, inp, ones_tensor) |
|
|
|