@@ -42,7 +42,7 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, | Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, | ||||
bool multi_batch_flag) { | |||||
bool multi_batch_flag) { | |||||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
@@ -74,7 +74,7 @@ Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, c | |||||
/// @return ge::NodePtr | /// @return ge::NodePtr | ||||
/// | /// | ||||
NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | ||||
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { | |||||
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { | |||||
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | ||||
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | ||||
@@ -32,7 +32,7 @@ Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { | |||||
OpDescPtr merge_op_desc = node->GetOpDesc(); | OpDescPtr merge_op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(merge_op_desc); | GE_CHECK_NOTNULL(merge_op_desc); | ||||
if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | ||||
GE_CHK_STATUS_RET(AddActiveNodes(graph, node, true), "Merge add active node failed."); | |||||
GE_CHK_STATUS_RET(AddActiveNodes(graph, node), "Merge add active node failed."); | |||||
GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); | GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); | ||||
} else { | } else { | ||||
GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); | GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); | ||||
@@ -99,18 +99,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co | |||||
} | } | ||||
} | } | ||||
return AddActiveNodes(graph, stream_merge, false); | |||||
return AddActiveNodes(graph, stream_merge); | |||||
} | } | ||||
/// | /// | ||||
/// @brief Add StreamActive Op before StreamMerge/Merge | /// @brief Add StreamActive Op before StreamMerge/Merge | ||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] multi_batch_flag | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node, | |||||
bool multi_batch_flag) { | |||||
Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | ||||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
@@ -37,10 +37,9 @@ class MergeToStreamMergePass : public GraphPass { | |||||
/// @brief Add StreamActive Op as StreamMerge in_node | /// @brief Add StreamActive Op as StreamMerge in_node | ||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] multi_batch_flag | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); | |||||
Status AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
/// | /// | ||||
/// @brief Create Active Op | /// @brief Create Active Op | ||||