Browse Source

handle mutable input

pull/1690/head
chenyemeng 4 years ago
parent
commit
c7c651f948
1 changed files with 24 additions and 17 deletions
  1. +24
    -17
      ge/graph/optimize/mem_rw_conflict_optimize.cc

+ 24
- 17
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -33,6 +33,8 @@ const int kCaseScopeWriteable = 2;
const int kCaseWriteable = 3; const int kCaseWriteable = 3;
const int kCaseInvalidRWType = 5; const int kCaseInvalidRWType = 5;


const char *const kInputMutable = "_input_mutable";

// rw type of input. // rw type of input.
enum class InputRWType { enum class InputRWType {
kReadOnly, // Normal op input only read kReadOnly, // Normal op input only read
@@ -634,24 +636,29 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
} }
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
for (const auto &node : compute_graph->GetDirectNode()) { for (const auto &node : compute_graph->GetDirectNode()) {
if (node->GetType() == HCOMALLREDUCE) {
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(pre_out_anchor);
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) {
pre_out_anchor_set.emplace(pre_out_anchor);
continue;
}
// need insert identity
auto pre_node = pre_out_anchor->GetOwnerNode();
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
// op_desc of node should not be null
const auto &op_desc = node->GetOpDesc();
bool mutable_input_flag = false;
if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) {
GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str());
continue;
}
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(pre_out_anchor);
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) {
pre_out_anchor_set.emplace(pre_out_anchor);
continue;
} }
// need insert identity
auto pre_node = pre_out_anchor->GetOwnerNode();
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
} }
} }
return SUCCESS; return SUCCESS;


Loading…
Cancel
Save