Browse Source

fix all reduce and loss overlap

pull/315/head
baker 4 years ago
parent
commit
8a35d811ca
2 changed files with 20 additions and 7 deletions
  1. +3
    -7
      ge/graph/build/logical_stream_allocator.cc
  2. +17
    -0
      ge/graph/build/logical_stream_allocator.h

+ 3
- 7
ge/graph/build/logical_stream_allocator.cc View File

@@ -363,13 +363,10 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr
}
}

// Update stream id for nodes belong to skipped engine subgraph
GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs));

return SUCCESS;
}

int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const {
int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const {
set<int64_t> stream_ids;

for (const auto &in_node : node->GetInAllNodes()) {
@@ -398,8 +395,7 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const {
return kInvalidStream;
}

Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph,
const vector<SubgraphPtr> &subgraphs) {
Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) {
set<OpDescPtr> ops_without_label;

// Check if subgraph is engine skipped and without stream label or not
@@ -441,7 +437,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph
return SUCCESS;
}

bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const {
bool UpdateForSkippedEnginePass::AreAllPredStreamsInvalid(const NodePtr &node) const {
for (const auto &pre_node : node->GetInAllNodes()) {
auto pre_node_desc = pre_node->GetOpDesc();
if (pre_node_desc != nullptr) {


+ 17
- 0
ge/graph/build/logical_stream_allocator.h View File

@@ -161,6 +161,23 @@ class NodeStreamUpdatePass : public LogicalStreamPass {
bool AreAllPredStreamsInvalid(const NodePtr &node) const;
};

// Update the stream of subgraphs to nodes.
class UpdateForSkippedEnginePass : public LogicalStreamPass {
public:
STREAM_PASS_DEFAULT_FUNC(UpdateForSkippedEnginePass);
/// Optimize for case like:
/// NodeA(stream1) -> Const(stream2) -> NodeB(stream1)
/// To case:
/// NodeA(stream1) -> Const(stream1) -> NodeB(stream1)
/// Which could reduce event number (Const could be other type which belong to skipped engine subgraph)
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;

private:
int64_t GetSingleInoutStream(const NodePtr &node) const;
// Judge if all predecessors' streams of node are kInvalidStream
bool AreAllPredStreamsInvalid(const NodePtr &node) const;
};

// AllReduce and backward operators execute in parallel.
class AllReduceParallelPass : public LogicalStreamPass {
public:


Loading…
Cancel
Save