Browse Source

fix(python_module/mge): one_hot: no default value for num_classes

GitOrigin-RevId: c4a5310880
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
4500faf1ea
2 changed files with 2 additions and 4 deletions
  1. +1
    -3
      python_module/megengine/functional/nn.py
  2. +1
    -1
      python_module/test/unit/functional/test_onehot.py

+ 1
- 3
python_module/megengine/functional/nn.py View File

@@ -427,7 +427,7 @@ def batch_norm2d(
return output


def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor:
def one_hot(inp: Tensor, num_classes: int) -> Tensor:
r"""
Perform one-hot encoding for the input tensor.

@@ -457,8 +457,6 @@ def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor:
"""
comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp)

if num_classes == -1:
num_classes = inp.max() + 1
zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph)
zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes)



+ 1
- 1
python_module/test/unit/functional/test_onehot.py View File

@@ -8,7 +8,7 @@ from megengine.test import assertTensorClose

def test_onehot_low_dimension():
inp = tensor(np.arange(1, 4, dtype=np.int32))
out = F.one_hot(inp)
out = F.one_hot(inp, num_classes=4)

assertTensorClose(
out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]


Loading…
Cancel
Save