Browse Source

fix(mgb/opr-mm): remove stream -4 for CollectiveComm

GitOrigin-RevId: 41ea88dfa1
release-1.1
Megvii Engine Team 4 years ago
parent
commit
4a5e317063
4 changed files with 6 additions and 15 deletions
  1. +1
    -5
      src/core/impl/graph/seq_comp_node_opt_impl.cpp
  2. +1
    -2
      src/core/include/megbrain/comp_node.h
  3. +1
    -5
      src/opr-mm/impl/collective_comm.cpp
  4. +3
    -3
      src/opr/test/basic_arith/others.cpp

+ 1
- 5
src/core/impl/graph/seq_comp_node_opt_impl.cpp View File

@@ -109,11 +109,8 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream(
type = any_strong_changed ?
StreamPropType::STRONG : StreamPropType::WEAK;
int copy_stream = CompNode::Stream::COPY;
- int nccl_stream = CompNode::Stream::NCCL;
if (inp_streams.count(copy_stream))
stream = copy_stream;
- else if (inp_streams.count(nccl_stream))
- stream = nccl_stream;
mgb_assert(type != StreamPropType::NONE && stream != 0);
}
return prop_type_storage.second = StreamPropType{stream, type};
@@ -188,8 +185,7 @@ void SeqCompNodeOptimizerImpl::register_stream_var(
mgb_assert(var->owner_graph() == m_owner_graph &&
(prop_type == StreamPropType::WEAK ||
prop_type == StreamPropType::STRONG));
- mgb_assert(stream == CompNode::Stream::COPY || stream ==
- CompNode::Stream::NCCL);
+ mgb_assert(stream == CompNode::Stream::COPY);

auto ins = m_var2prop_type.insert({var, {stream, prop_type}});
if (!ins.second) {


+ 1
- 2
src/core/include/megbrain/comp_node.h View File

@@ -207,8 +207,7 @@ class CompNode {
static constexpr int
COPY = -1,
REMOTE_SEND = -2,
- LOOP_SWAP = -3,
- NCCL = -4;
+ LOOP_SWAP = -3;
};

CompNode() = default;


+ 1
- 5
src/opr-mm/impl/collective_comm.cpp View File

@@ -630,11 +630,7 @@ void CollectiveComm::get_output_var_shape(const TensorShapeArray& inp_shape,
inp_shape, out_shape);
}

- void CollectiveComm::init_output_comp_node() {
- mgb_assert(output().size() == 1, "exactly one output expected, got %zu", output().size());
- owner_graph()->seq_comp_node_optimizer().register_stream_var(output()[0],
- {CompNode::Stream::NCCL, cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
- }
+ void CollectiveComm::init_output_comp_node() {}

void CollectiveComm::init_output_mem_plan(bool dynamic) {
for (size_t i = 0; i < output().size(); i++) {


+ 3
- 3
src/opr/test/basic_arith/others.cpp View File

@@ -269,13 +269,13 @@ TEST(TestOprBasicArith, AddUpdateOtherStream) {
};

std::shared_ptr<HostTensorND> host_val = gen({SIZE});
- auto cn_nccl = CompNode::load("gpu0").change_stream(CompNode::Stream::NCCL);
+ auto cn1 = CompNode::load("gpu0:0").change_stream(1);
auto param = opr::SharedDeviceTensor::make(*graph, *host_val);
param.node()->owner_opr()->node_prop().attribute().priority =
std::numeric_limits<int>::max();
- auto copy = opr::Copy::make(param, cn_nccl);
+ auto copy = opr::Copy::make(param, cn1);
auto add = (copy + 3) * 5;
- auto add_update = opr::AddUpdate::make(param, add, {}, {cn_nccl});
+ auto add_update = opr::AddUpdate::make(param, add, {}, {cn1});

auto callback = opr::CallbackInjector::make(add_update, set_flag);



Loading…
Cancel
Save