@@ -202,7 +202,7 @@ Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | |||||
GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | |||||
base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | |||||
mem_size = rdma_mem_size_; | mem_size = rdma_mem_size_; | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -701,6 +701,9 @@ Status HybridModelBuilder::LoadGraph() { | |||||
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | ||||
"[%s] Failed to identify ref outputs.", | "[%s] Failed to identify ref outputs.", | ||||
parent_node_item->NodeName().c_str()); | parent_node_item->NodeName().c_str()); | ||||
GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), | |||||
"[%s] Failed to identify same outputs.", | |||||
parent_node_item->NodeName().c_str()); | |||||
// if parent is function control op. need add a virtual partitioned call | // if parent is function control op. need add a virtual partitioned call | ||||
if (parent_node_item->IsControlOp()) { | if (parent_node_item->IsControlOp()) { | ||||
@@ -1162,6 +1165,46 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { | |||||
GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); | |||||
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||||
if (net_output_node == nullptr) { | |||||
GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
auto net_output_desc = net_output_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(net_output_desc); | |||||
std::map<std::string, int> connected_inputs; | |||||
for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (out_data_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto src_node = out_data_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
auto op_desc = src_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
auto it = connected_inputs.find(input_key); | |||||
if (it == connected_inputs.end()) { | |||||
connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); | |||||
} else { | |||||
GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
in_data_anchor->GetIdx(), | |||||
it->second, | |||||
src_node->GetName().c_str(), | |||||
out_data_anchor->GetIdx()); | |||||
node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | ||||
GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | ||||
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | ||||
@@ -59,6 +59,7 @@ class HybridModelBuilder { | |||||
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
Status LoadTasks(); | Status LoadTasks(); | ||||
Status IdentifyVariableOutputs(NodeItem &node_item); | Status IdentifyVariableOutputs(NodeItem &node_item); | ||||
Status IdentifySameInputs(NodeItem &node_item); | |||||
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
@@ -83,6 +83,7 @@ struct NodeItem { | |||||
const NodeExecutor *node_executor = nullptr; | const NodeExecutor *node_executor = nullptr; | ||||
std::map<int, ge::NodePtr> ref_outputs; | std::map<int, ge::NodePtr> ref_outputs; | ||||
std::map<int, int> reuse_inputs; | std::map<int, int> reuse_inputs; | ||||
std::map<int, int> reuse_outputs; | |||||
std::vector<bool> is_input_shape_static; | std::vector<bool> is_input_shape_static; | ||||
bool is_output_shape_static = true; | bool is_output_shape_static = true; | ||||
@@ -189,13 +189,20 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
} | } | ||||
GE_CHECK_NOTNULL(tv); | GE_CHECK_NOTNULL(tv); | ||||
auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); | auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); | ||||
addr_infos.resize(dims.front()); | |||||
for (auto idx = 0; idx < dims.front(); ++idx) { | |||||
auto row_num = dims.front(); | |||||
addr_infos.resize(row_num); | |||||
auto device_len = tv->GetSize() / row_num; | |||||
if (device_len <= 0 || device_len > data[kVarTableIdxLen]) { | |||||
GELOGE(FAILED, "Local embedding length is out of range."); | |||||
return FAILED; | |||||
} | |||||
for (auto idx = 0; idx < row_num; ++idx) { | |||||
FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | ||||
auto line_idx = idx * kVarTableRowCnt; | auto line_idx = idx * kVarTableRowCnt; | ||||
addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr, | addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr, | ||||
data[line_idx + kVarTableIdxLen]}; | |||||
local_addr += data[line_idx + kVarTableIdxLen]; | |||||
device_len}; | |||||
local_addr += device_len; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -221,16 +221,22 @@ Status TaskContext::AllocateOutput(int index, | |||||
GE_CHECK_NOTNULL(ref_tensor); | GE_CHECK_NOTNULL(ref_tensor); | ||||
outputs_start_[index] = *ref_tensor; | outputs_start_[index] = *ref_tensor; | ||||
} else { | } else { | ||||
auto reuse_input = node_item_->reuse_inputs.find(index); | |||||
if (reuse_input != node_item_->reuse_inputs.end()) { | |||||
GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); | |||||
outputs_start_[index] = inputs_start_[reuse_input->second]; | |||||
auto reuse_output_it = node_item_->reuse_outputs.find(index); | |||||
if (reuse_output_it != node_item_->reuse_outputs.end()) { | |||||
GELOGD("[%s] reuse output [%d] with output [%d]", GetNodeName(), index, reuse_output_it->second); | |||||
outputs_start_[index] = outputs_start_[reuse_output_it->second]; | |||||
} else { | } else { | ||||
GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); | |||||
GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", | |||||
node_item_->NodeName().c_str(), | |||||
index, | |||||
outputs_start_[index].GetSize()); | |||||
auto reuse_input = node_item_->reuse_inputs.find(index); | |||||
if (reuse_input != node_item_->reuse_inputs.end()) { | |||||
GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); | |||||
outputs_start_[index] = inputs_start_[reuse_input->second]; | |||||
} else { | |||||
GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); | |||||
GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", | |||||
node_item_->NodeName().c_str(), | |||||
index, | |||||
outputs_start_[index].GetSize()); | |||||
} | |||||
} | } | ||||
} | } | ||||