diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py index b7a8140f..93fd66b7 100644 --- a/python_module/megengine/functional/nn.py +++ b/python_module/megengine/functional/nn.py @@ -427,7 +427,7 @@ def batch_norm2d( return output -def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor: +def one_hot(inp: Tensor, num_classes: int) -> Tensor: r""" Perform one-hot encoding for the input tensor. @@ -457,8 +457,6 @@ def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor: """ comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp) - if num_classes == -1: - num_classes = inp.max() + 1 zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph) zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes) diff --git a/python_module/test/unit/functional/test_onehot.py b/python_module/test/unit/functional/test_onehot.py index 808323c1..3edbe5de 100644 --- a/python_module/test/unit/functional/test_onehot.py +++ b/python_module/test/unit/functional/test_onehot.py @@ -8,7 +8,7 @@ from megengine.test import assertTensorClose def test_onehot_low_dimension(): inp = tensor(np.arange(1, 4, dtype=np.int32)) - out = F.one_hot(inp) + out = F.one_hot(inp, num_classes=4) assertTensorClose( out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]