From 4500faf1eaa39a57e9c3a4de1e80e7b10445f8eb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 30 Apr 2020 10:10:53 +0800 Subject: [PATCH] fix(python_module/mge): one_hot: no default value for num_classes GitOrigin-RevId: c4a53108806ddfb3e37ef7faa5d88afa31718d74 --- python_module/megengine/functional/nn.py | 4 +--- python_module/test/unit/functional/test_onehot.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py index b7a8140f..93fd66b7 100644 --- a/python_module/megengine/functional/nn.py +++ b/python_module/megengine/functional/nn.py @@ -427,7 +427,7 @@ def batch_norm2d( return output -def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor: +def one_hot(inp: Tensor, num_classes: int) -> Tensor: r""" Perform one-hot encoding for the input tensor. @@ -457,8 +457,6 @@ def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor: """ comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp) - if num_classes == -1: - num_classes = inp.max() + 1 zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph) zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes) diff --git a/python_module/test/unit/functional/test_onehot.py b/python_module/test/unit/functional/test_onehot.py index 808323c1..3edbe5de 100644 --- a/python_module/test/unit/functional/test_onehot.py +++ b/python_module/test/unit/functional/test_onehot.py @@ -8,7 +8,7 @@ from megengine.test import assertTensorClose def test_onehot_low_dimension(): inp = tensor(np.arange(1, 4, dtype=np.int32)) - out = F.one_hot(inp) + out = F.one_hot(inp, num_classes=4) assertTensorClose( out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]