From 8a35d811ca206150061a88b1183b722087d4ba5b Mon Sep 17 00:00:00 2001 From: baker Date: Tue, 17 Nov 2020 22:07:15 +0800 Subject: [PATCH] fix all reduce and loss overlap --- ge/graph/build/logical_stream_allocator.cc | 10 +++------- ge/graph/build/logical_stream_allocator.h | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 5c8bc46c..7ee65f94 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -363,13 +363,10 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector stream_ids; for (const auto &in_node : node->GetInAllNodes()) { @@ -398,8 +395,7 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { return kInvalidStream; } -Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, - const vector &subgraphs) { +Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { set ops_without_label; // Check if subgraph is engine skipped and without stream label or not @@ -441,7 +437,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph return SUCCESS; } -bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { +bool UpdateForSkippedEnginePass::AreAllPredStreamsInvalid(const NodePtr &node) const { for (const auto &pre_node : node->GetInAllNodes()) { auto pre_node_desc = pre_node->GetOpDesc(); if (pre_node_desc != nullptr) { diff --git a/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h index 0aebb9b4..46d4c44d 100644 --- a/ge/graph/build/logical_stream_allocator.h +++ b/ge/graph/build/logical_stream_allocator.h @@ -161,6 +161,23 @@ class NodeStreamUpdatePass : public LogicalStreamPass { bool AreAllPredStreamsInvalid(const NodePtr &node) const; }; +// Update the stream of subgraphs to nodes. +class UpdateForSkippedEnginePass : public LogicalStreamPass { + public: + STREAM_PASS_DEFAULT_FUNC(UpdateForSkippedEnginePass); + /// Optimize for case like: + /// NodeA(stream1) -> Const(stream2) -> NodeB(stream1) + /// To case: + /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) + /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + + private: + int64_t GetSingleInoutStream(const NodePtr &node) const; + // Judge if all predecessors' streams of node are kInvalidStream + bool AreAllPredStreamsInvalid(const NodePtr &node) const; +}; + // AllReduce and backward operators execute in parallel. class AllReduceParallelPass : public LogicalStreamPass { public: