Browse Source

test(mge/collective_comm): fix collective_comm test and add data parallel test

GitOrigin-RevId: 9209e77973
release-0.6
Megvii Engine Team 4 years ago
parent
commit
e3e981ccf0
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      src/opr-mm/impl/collective_comm.cpp

+ 2
- 2
src/opr-mm/impl/collective_comm.cpp View File

@@ -802,7 +802,7 @@ void CollectiveComm::init_output_static_infer_desc() {
if (m_param.mode == Param::Mode::SCATTER) {
dest[0] /= nr_devices();
}
if (!m_output_shape.valid()) {
if (is_root() && !m_output_shape.valid()) {
m_output_shape = dest;
m_group_client->set_output_shape(m_key, dest);
}
@@ -824,7 +824,7 @@ void CollectiveComm::init_output_static_infer_desc() {

mgb_assert(output().size() == 1);

if (is_root()) {
if (is_root() || input().size() > 0) {
mgb_assert(input().size() == 1);
mgr.register_shape_infer(output(0),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input});


Loading…
Cancel
Save