|
@@ -802,7 +802,7 @@ void CollectiveComm::init_output_static_infer_desc() { |
|
|
if (m_param.mode == Param::Mode::SCATTER) { |
|
|
if (m_param.mode == Param::Mode::SCATTER) { |
|
|
dest[0] /= nr_devices(); |
|
|
dest[0] /= nr_devices(); |
|
|
} |
|
|
} |
|
|
if (!m_output_shape.valid()) { |
|
|
|
|
|
|
|
|
if (is_root() && !m_output_shape.valid()) { |
|
|
m_output_shape = dest; |
|
|
m_output_shape = dest; |
|
|
m_group_client->set_output_shape(m_key, dest); |
|
|
m_group_client->set_output_shape(m_key, dest); |
|
|
} |
|
|
} |
|
@@ -824,7 +824,7 @@ void CollectiveComm::init_output_static_infer_desc() { |
|
|
|
|
|
|
|
|
mgb_assert(output().size() == 1); |
|
|
mgb_assert(output().size() == 1); |
|
|
|
|
|
|
|
|
if (is_root()) { |
|
|
|
|
|
|
|
|
if (is_root() || input().size() > 0) { |
|
|
mgb_assert(input().size() == 1); |
|
|
mgb_assert(input().size() == 1); |
|
|
mgr.register_shape_infer(output(0), |
|
|
mgr.register_shape_infer(output(0), |
|
|
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input}); |
|
|
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input}); |
|
|