|
|
@@ -271,6 +271,7 @@ def all_reduce_max( |
|
|
|
:param inp: input tensor. |
|
|
|
:param group: communication group. |
|
|
|
:param device: execution device. |
|
|
|
:returns: reduced tensor. |
|
|
|
""" |
|
|
|
mode = CollectiveComm.Mode.ALL_REDUCE_MAX |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|