Browse Source

fix

pull/277/head
wangxiaotian22 4 years ago
parent
commit
634f87da5d
2 changed files with 14 additions and 11 deletions
  1. +6
    -6
      ge/graph/build/memory/block_mem_assigner.cc
  2. +8
    -5
      ge/graph/passes/hccl_memcpy_pass.cc

+ 6
- 6
ge/graph/build/memory/block_mem_assigner.cc View File

@@ -826,6 +826,12 @@ bool BlockMemAssigner::IsContinuousOutput(const NodePtr &n) {
return false;
}

// if output size just one, no need to reassign continuous memory
if (node_op_desc->GetOutputsSize() == 1) {
GELOGI("op %s output size is one, no need to continuous process.", n->GetName().c_str());
return false;
}

// Get the continuous output type of the node, default is false
bool is_output_continuous = false;
auto node_desc = n->GetOpDesc();
@@ -939,12 +945,6 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "input node is null.");
auto node_op_desc = n->GetOpDesc();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node_op_desc == nullptr, return nullptr, "node_op_desc is null.");
// if output size just one, no need to reassign continuous memory
if (node_op_desc->GetOutputsSize() == 1) {
zero_memory_list_.emplace_back(n, kOutput, 0);
return nullptr;
}

MemoryBlock *block = nullptr;
int64_t total_size = 0;


+ 8
- 5
ge/graph/passes/hccl_memcpy_pass.cc View File

@@ -58,7 +58,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
GELOGE(INTERNAL_ERROR, "failed P2pmemInputProcess, node_name:%s.", node->GetName().c_str());
return ret;
}
}
return ret;
}
@@ -66,7 +66,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
// If a node has the _input_mutable attr, its input memory may be modified when the op executes.
// To avoid affecting another op that shares the same input when the data is modified,
// a memcpy node needs to be inserted in between.
// also works on situation that input is variable or const.
// also works on situation that input is variable or const.
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();

@@ -77,7 +77,7 @@ Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const N

if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) {
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str());
return FAILED;
return FAILED;
}
if (!node_input_mutable) {
return SUCCESS;
@@ -126,6 +126,7 @@ Status HcclMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, cons
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);

if (is_input_continuous && op_desc->GetInputsSize() > 1) {
GELOGI("continuous input op is:%s.", op_desc->GetName().c_str());
// if input size is bigger than one, insert memcpy between var data to support continuous mem alloc
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
@@ -136,7 +137,7 @@ Status HcclMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, cons
GELOGE(INTERNAL_ERROR, "hcom op input has no peer anchor, node_name:%s", node->GetName().c_str());
return INTERNAL_ERROR;
}
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
@@ -165,6 +166,7 @@ Status HcclMemcpyPass::P2pmemInputProcess(const ComputeGraphPtr &graph, const No
continue;
}

GELOGI("p2p input op is:%s.", op_desc->GetName().c_str());
auto hccl_in_anchor = node->GetInDataAnchor(index);
if (hccl_in_anchor == nullptr) {
continue;
@@ -263,7 +265,8 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) {
///
Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
GELOGI("The op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str());
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor);
GE_CHECK_NOTNULL(memcpy_node);



Loading…
Cancel
Save