GitOrigin-RevId: 41ea88dfa1
release-1.1
@@ -109,11 +109,8 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( | |||||
type = any_strong_changed ? | type = any_strong_changed ? | ||||
StreamPropType::STRONG : StreamPropType::WEAK; | StreamPropType::STRONG : StreamPropType::WEAK; | ||||
int copy_stream = CompNode::Stream::COPY; | int copy_stream = CompNode::Stream::COPY; | ||||
int nccl_stream = CompNode::Stream::NCCL; | |||||
if (inp_streams.count(copy_stream)) | if (inp_streams.count(copy_stream)) | ||||
stream = copy_stream; | stream = copy_stream; | ||||
else if (inp_streams.count(nccl_stream)) | |||||
stream = nccl_stream; | |||||
mgb_assert(type != StreamPropType::NONE && stream != 0); | mgb_assert(type != StreamPropType::NONE && stream != 0); | ||||
} | } | ||||
return prop_type_storage.second = StreamPropType{stream, type}; | return prop_type_storage.second = StreamPropType{stream, type}; | ||||
@@ -188,8 +185,7 @@ void SeqCompNodeOptimizerImpl::register_stream_var( | |||||
mgb_assert(var->owner_graph() == m_owner_graph && | mgb_assert(var->owner_graph() == m_owner_graph && | ||||
(prop_type == StreamPropType::WEAK || | (prop_type == StreamPropType::WEAK || | ||||
prop_type == StreamPropType::STRONG)); | prop_type == StreamPropType::STRONG)); | ||||
mgb_assert(stream == CompNode::Stream::COPY || stream == | |||||
CompNode::Stream::NCCL); | |||||
mgb_assert(stream == CompNode::Stream::COPY); | |||||
auto ins = m_var2prop_type.insert({var, {stream, prop_type}}); | auto ins = m_var2prop_type.insert({var, {stream, prop_type}}); | ||||
if (!ins.second) { | if (!ins.second) { | ||||
@@ -207,8 +207,7 @@ class CompNode { | |||||
static constexpr int | static constexpr int | ||||
COPY = -1, | COPY = -1, | ||||
REMOTE_SEND = -2, | REMOTE_SEND = -2, | ||||
LOOP_SWAP = -3, | |||||
NCCL = -4; | |||||
LOOP_SWAP = -3; | |||||
}; | }; | ||||
CompNode() = default; | CompNode() = default; | ||||
@@ -630,11 +630,7 @@ void CollectiveComm::get_output_var_shape(const TensorShapeArray& inp_shape, | |||||
inp_shape, out_shape); | inp_shape, out_shape); | ||||
} | } | ||||
void CollectiveComm::init_output_comp_node() { | |||||
mgb_assert(output().size() == 1, "exactly one output expected, got %zu", output().size()); | |||||
owner_graph()->seq_comp_node_optimizer().register_stream_var(output()[0], | |||||
{CompNode::Stream::NCCL, cg::SeqCompNodeOptimizer::StreamPropType::WEAK}); | |||||
} | |||||
void CollectiveComm::init_output_comp_node() {} | |||||
void CollectiveComm::init_output_mem_plan(bool dynamic) { | void CollectiveComm::init_output_mem_plan(bool dynamic) { | ||||
for (size_t i = 0; i < output().size(); i++) { | for (size_t i = 0; i < output().size(); i++) { | ||||
@@ -269,13 +269,13 @@ TEST(TestOprBasicArith, AddUpdateOtherStream) { | |||||
}; | }; | ||||
std::shared_ptr<HostTensorND> host_val = gen({SIZE}); | std::shared_ptr<HostTensorND> host_val = gen({SIZE}); | ||||
auto cn_nccl = CompNode::load("gpu0").change_stream(CompNode::Stream::NCCL); | |||||
auto cn1 = CompNode::load("gpu0:0").change_stream(1); | |||||
auto param = opr::SharedDeviceTensor::make(*graph, *host_val); | auto param = opr::SharedDeviceTensor::make(*graph, *host_val); | ||||
param.node()->owner_opr()->node_prop().attribute().priority = | param.node()->owner_opr()->node_prop().attribute().priority = | ||||
std::numeric_limits<int>::max(); | std::numeric_limits<int>::max(); | ||||
auto copy = opr::Copy::make(param, cn_nccl); | |||||
auto copy = opr::Copy::make(param, cn1); | |||||
auto add = (copy + 3) * 5; | auto add = (copy + 3) * 5; | ||||
auto add_update = opr::AddUpdate::make(param, add, {}, {cn_nccl}); | |||||
auto add_update = opr::AddUpdate::make(param, add, {}, {cn1}); | |||||
auto callback = opr::CallbackInjector::make(add_update, set_flag); | auto callback = opr::CallbackInjector::make(add_update, set_flag); | ||||