Browse Source

fix

pull/277/head
wangxiaotian22 4 years ago
parent
commit
94f2185b9b
3 changed files with 117 additions and 47 deletions
  1. +1
    -1
      ge/graph/build/memory/block_mem_assigner.cc
  2. +109
    -45
      ge/graph/passes/hccl_memcpy_pass.cc
  3. +7
    -1
      ge/graph/passes/hccl_memcpy_pass.h

+ 1
- 1
ge/graph/build/memory/block_mem_assigner.cc View File

@@ -428,7 +428,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
// if input size just one, no need to reassign continuous memory
bool is_input_continuous = false;
(void)ge::AttrUtils::GetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
if (is_input_continuous && (node_op_desc->GetInputSize() <= 1)) {
if (is_input_continuous && (node_op_desc->GetInputsSize() <= 1)) {
(void)ge::AttrUtils::SetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true);
}



+ 109
- 45
ge/graph/passes/hccl_memcpy_pass.cc View File

@@ -32,75 +32,101 @@ const char *const kInputMutable = "_input_mutable";
} // namespace
namespace ge {
// Entry point of the pass: walks every direct node of the graph and applies the per-node
// HCCL memcpy handling (continuous input, mutable input, p2p input).
// Returns PARAM_INVALID when graph is null; otherwise stops at the first step that does not return
// SUCCESS and returns that status.
// NOTE(review): this listing is taken from a rendered diff, so removed and added lines appear side by
// side without +/- markers (e.g. the two op_desc null checks and the two consecutive GELOGE calls) —
// confirm against the committed file before relying on the exact statement sequence.
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
Status ret = SUCCESS;
GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(PARAM_INVALID, "param [graph] must not be null."); return PARAM_INVALID);
for (const auto &node : graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
// A node without op_desc is reported as INTERNAL_ERROR here.
if (op_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "node has no op_desc, node_name : %s.", node->GetName().c_str());
return INTERNAL_ERROR;
}

Status ret = ProcessBroadcastMemcpy(graph, node);
ret = ContinuousInputProcess(graph, node);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "failed ProcessBroadcastMemcpy.");
// NOTE(review): the message below still names ProcessBroadcastMemcpy although the call above is
// ContinuousInputProcess — presumably a leftover from the rename; confirm and update.
GELOGE(INTERNAL_ERROR, "failed ProcessBroadcastMemcpy, node_name:%s.", node->GetName().c_str());
return ret;
}

bool node_input_mutable = false;
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) {
continue;
ret = MutableInputProcess(graph, node);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "failed MutableInputProcess, node_name:%s.", node->GetName().c_str());
return ret;
}

ret = P2pmemInputProcess(graph, node);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "failed P2pmemInputProcess, node_name:%s.", node->GetName().c_str());
return ret;
}
}
return ret;
}

// If a node carries the _input_mutable attr, its input memory may be modified while the op executes.
// To keep that modification from affecting another op that reads the same input,
// a memcpy node needs to be inserted in between.
// This also covers the situation where the input is a variable or a const.
// Returns SUCCESS when the attr is absent or false, FAILED when the attr exists but cannot be read,
// and otherwise the status of the first ModifyEdgeConnection call that fails.
// NOTE(review): this listing is taken from a rendered diff, so removed and added lines appear side by
// side without +/- markers (the attr check and the anchor loop each appear twice) — confirm against
// the committed file before relying on the exact statement sequence.
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();

bool node_input_mutable = false;
// Nodes without the attr need no handling.
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) {
return SUCCESS;
}

if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) {
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str());
return FAILED;
}
if (!node_input_mutable) {
return SUCCESS;
}

GE_IF_BOOL_EXEC(!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable),
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); return FAILED);
if (!node_input_mutable) {
GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str());
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
continue;
}
auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(src_out_anchor);

GELOGI("hcom op is:%s.", op_desc->GetName().c_str());
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
continue;
}
auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(src_out_anchor);

int32_t src_out_anchor_size = src_out_anchor->GetPeerInDataAnchors().size();
if (src_out_anchor_size == kAnchorSize) {
// Memcpyasync needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared.
NodePtr src_node = src_out_anchor->GetOwnerNode();
std::string src_type = src_node->GetType();
bool check_src_type = (src_type == CONSTANTOP) || (src_type == VARIABLE) || (src_type == DATA) || (src_type == CONSTANT);
if (check_src_type) {
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
return ret;
}
// Number of in-data anchors fed by this source output; kAnchorSize is defined outside this view.
int32_t src_out_anchor_size = src_out_anchor->GetPeerInDataAnchors().size();
if (src_out_anchor_size == kAnchorSize) {
// Memcpyasync needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared.
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
return ret;
}
continue;
}
continue;
}

Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
return ret;
}
// Reached when the peer count differs from kAnchorSize: the edge is always rewired through ModifyEdgeConnection.
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
return ret;
}
}
return SUCCESS;
}


// If a broadcast node has more than one input and the inputs come from variables,
// the broadcast input memory has to be continuous, so
// a separate featuremap memory block is allocated for the broadcast inputs.
// In that case the data must be moved from variable memory to the broadcast input featuremap memory on each step.
// To avoid performing this move outside the model, a memcpy node is used instead of move-action code.
Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, const NodePtr node) {
Status HcclMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "node has no op_desc, node_name : %s.", node->GetName().c_str());
return INTERNAL_ERROR;
}

if ((node->GetType() == HCOMBROADCAST || node->GetType() == HVDCALLBACKBROADCAST) && op_desc->GetInputSize() > 1) {
bool is_input_continuous = false;
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);

if (is_input_continuous && op_desc->GetInputsSize() > 1) {
// if input size is bigger than one, insert memcpy between var data to support continuous mem alloc
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
continue;
@@ -111,10 +137,7 @@ Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, cons
return INTERNAL_ERROR;
}
NodePtr src_node = src_out_anchor->GetOwnerNode();
std::string src_type = src_node->GetType();
bool check_src_type = (src_type == CONSTANTOP) || (src_type == VARIABLE) || (src_type == DATA) || (src_type == CONSTANT);
if (check_src_type) {
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
@@ -126,6 +149,47 @@ Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, cons
return SUCCESS;
}

// If an input of the node requires p2p memory (RT_MEMORY_P2P_DDR in ATTR_NAME_INPUT_MEM_TYPE_LIST) and that
// input is fed directly by a data-like node (see IsDataNode), a memcpy is inserted between the two.
// @param [in] graph  graph handed through to ModifyEdgeConnection
// @param [in] node   node whose inputs are inspected
// @return SUCCESS when nothing has to be done or every required memcpy was inserted,
//         INTERNAL_ERROR when the node has no op_desc or a p2p input has no peer out anchor,
//         otherwise the failing status of ModifyEdgeConnection
Status HcclMemcpyPass::P2pmemInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(INTERNAL_ERROR, "node has no op_desc, node_name : %s.", node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  std::vector<int64_t> input_memory_types;
  // The result is ignored; an empty list is treated as "nothing to do" right below.
  (void) ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, input_memory_types);

  if (input_memory_types.empty()) {
    return SUCCESS;
  }

  // Only indices present in both the memory type list and the op's inputs are inspected.
  // An unsigned index avoids the signed/unsigned comparison against size().
  const size_t input_num = static_cast<size_t>(op_desc->GetInputsSize());
  for (size_t index = 0; index < input_memory_types.size() && index < input_num; ++index) {
    if (input_memory_types[index] != RT_MEMORY_P2P_DDR) {
      continue;
    }

    // The anchor belongs to the node, so it has to be looked up through it.
    auto hccl_in_anchor = node->GetInDataAnchor(static_cast<int>(index));
    if (hccl_in_anchor == nullptr) {
      continue;
    }
    auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
    if (src_out_anchor == nullptr) {
      GELOGE(INTERNAL_ERROR, "hcom op input has no peer anchor, node_name:%s", node->GetName().c_str());
      return INTERNAL_ERROR;
    }

    if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
      Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
      if (ret != SUCCESS) {
        GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
        return ret;
      }
    }
  }
  return SUCCESS;
}

// Tells whether node_type names one of the source node kinds this pass treats as data:
// CONSTANTOP, VARIABLE, DATA or CONSTANT.
bool HcclMemcpyPass::IsDataNode(const std::string& node_type) {
  if (node_type == CONSTANTOP) {
    return true;
  }
  if (node_type == VARIABLE) {
    return true;
  }
  if (node_type == DATA) {
    return true;
  }
  return node_type == CONSTANT;
}

///
/// @brief Add MemcpyAsync Node
/// @param [in] ge::ComputeGraphPtr graph


+ 7
- 1
ge/graph/passes/hccl_memcpy_pass.h View File

@@ -37,7 +37,13 @@ class HcclMemcpyPass : public GraphPass {
Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor);

Status ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, const NodePtr node);
Status ContinuousInputProcess(const ComputeGraphPtr &graph, const NodePtr node);

Status MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node);

Status P2pmemInputProcess(const ComputeGraphPtr &graph, const NodePtr node);

// Returns true when node_type is CONSTANTOP, VARIABLE, DATA or CONSTANT.
// Declared without the class-name qualifier: a qualified name is not allowed on a member
// declaration inside its own class body.
bool IsDataNode(const std::string& node_type);

std::unordered_map<std::string, uint32_t> node_num_map_;
};


Loading…
Cancel
Save