|
@@ -427,7 +427,7 @@ def batch_norm2d(
     return output
 
 
-def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor:
+def one_hot(inp: Tensor, num_classes: int) -> Tensor:
     r"""
     Perform one-hot encoding for the input tensor.
 
@@ -457,8 +457,6 @@ def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor:
     """
     comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp)
 
-    if num_classes == -1:
-        num_classes = inp.max() + 1
 
     zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph)
     zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes)
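For context, here is a minimal, self-contained sketch of the semantics implied by the new signature: the caller now passes num_classes explicitly instead of relying on the removed inp.max() + 1 inference. This is an illustrative NumPy reference only; the helper name one_hot_reference and the NumPy-based implementation are assumptions, not MegEngine's actual graph-based code shown in the diff above.

# Illustrative sketch (assumption, not MegEngine code): one-hot encoding with an
# explicitly supplied num_classes, matching the updated signature.
import numpy as np

def one_hot_reference(inp: np.ndarray, num_classes: int) -> np.ndarray:
    # Allocate a zero tensor of shape inp.shape + (num_classes,) ...
    out = np.zeros(inp.shape + (num_classes,), dtype=np.int32)
    # ... and write a 1 at each position indexed by the integer labels in inp.
    np.put_along_axis(out, inp[..., None], 1, axis=-1)
    return out

labels = np.array([0, 2, 1])
print(one_hot_reference(labels, num_classes=3))
# [[1 0 0]
#  [0 0 1]
#  [0 1 0]]

With the default removed, a call such as one_hot(inp) without num_classes would now raise a TypeError; callers are expected to state the class count up front rather than have it derived from the data.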