@@ -488,6 +488,9 @@ void DavinciModel::InitRuntimeParams() { | |||||
session_scope_mem_info.memory_size = static_cast<size_t>(ret ? value : 0); | session_scope_mem_info.memory_size = static_cast<size_t>(ret ? value : 0); | ||||
runtime_param_.memory_infos[kSessionScopeMemory | RT_MEMORY_HBM] = std::move(session_scope_mem_info); | runtime_param_.memory_infos[kSessionScopeMemory | RT_MEMORY_HBM] = std::move(session_scope_mem_info); | ||||
ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, value); | |||||
runtime_param_.zero_copy_size = ret ? (uint64_t)value : 0; | |||||
GELOGI("InitRuntimeParams(), %s.", runtime_param_.ToString().c_str()); | GELOGI("InitRuntimeParams(), %s.", runtime_param_.ToString().c_str()); | ||||
} | } | ||||
@@ -265,6 +265,9 @@ class DavinciModel { | |||||
size_t TotalVarMemSize() const { return runtime_param_.var_size; } | size_t TotalVarMemSize() const { return runtime_param_.var_size; } | ||||
// get total zero copy size | |||||
size_t TotalZeroCopySize() const { return runtime_param_.zero_copy_size; } | |||||
// get base memory address | // get base memory address | ||||
uint8_t *MemBase() { return mem_base_; } | uint8_t *MemBase() { return mem_base_; } | ||||
@@ -49,7 +49,7 @@ struct RuntimeParam { | |||||
<< ", label_num:" << label_num << ", logic_mem_base:" << logic_mem_base | << ", label_num:" << label_num << ", logic_mem_base:" << logic_mem_base | ||||
<< ", logic_weight_base:" << logic_weight_base << ", logic_var_base:" << logic_var_base | << ", logic_weight_base:" << logic_weight_base << ", logic_var_base:" << logic_var_base | ||||
<< ", memory_size:" << mem_size << ", weight_size:" << weight_size << ", var_size:" << var_size | << ", memory_size:" << mem_size << ", weight_size:" << weight_size << ", var_size:" << var_size | ||||
<< ", ex_memory_info:"; | |||||
<< ", zero_copy_size:" << zero_copy_size << ", ex_memory_info:"; | |||||
for (auto it : memory_infos) { | for (auto it : memory_infos) { | ||||
ss << "[memory_type:" << it.first << ", memory_size:" << it.second.memory_size << "]"; | ss << "[memory_type:" << it.first << ", memory_size:" << it.second.memory_size << "]"; | ||||
} | } | ||||
@@ -65,6 +65,7 @@ struct RuntimeParam { | |||||
uint64_t var_size = 0; | uint64_t var_size = 0; | ||||
uint64_t logic_var_base = 0; | uint64_t logic_var_base = 0; | ||||
uint8_t *var_base = nullptr; | uint8_t *var_base = nullptr; | ||||
size_t zero_copy_size = 0; | |||||
std::map<uint64_t, MemInfo> memory_infos; | std::map<uint64_t, MemInfo> memory_infos; | ||||
uint32_t batch_num = 0; | uint32_t batch_num = 0; | ||||
uint32_t stream_num = 0; | uint32_t stream_num = 0; | ||||
@@ -101,10 +101,12 @@ Status KnownNodeTask::Init(TaskContext &context) { | |||||
GE_CHK_STATUS_RET(context.AllocateOutputs(), "[Allocate][Outputs] failed for %s.", context.GetNodeName()); | GE_CHK_STATUS_RET(context.AllocateOutputs(), "[Allocate][Outputs] failed for %s.", context.GetNodeName()); | ||||
// allocate mem base | // allocate mem base | ||||
void *buffer = nullptr; | void *buffer = nullptr; | ||||
if (davinci_model_->TotalMemSize() != 0) { | |||||
size_t total_mem_size = davinci_model_->TotalMemSize(); | |||||
size_t total_zero_copy_size = davinci_model_->TotalZeroCopySize(); | |||||
if (total_mem_size != 0 && total_mem_size > total_zero_copy_size) { | |||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), | ||||
"[KnownNodeTask_AllocateWorkspace] Start"); | "[KnownNodeTask_AllocateWorkspace] Start"); | ||||
GE_CHK_STATUS_RET(context.AllocateWorkspace(davinci_model_->TotalMemSize(), &buffer, | |||||
GE_CHK_STATUS_RET(context.AllocateWorkspace(total_mem_size - total_zero_copy_size, &buffer, | |||||
davinci_model_->GetRuntimeParam().mem_base), | davinci_model_->GetRuntimeParam().mem_base), | ||||
"[Allocate][Workspace] failed for %s.", context.GetNodeName()); | "[Allocate][Workspace] failed for %s.", context.GetNodeName()); | ||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), | ||||
@@ -1 +1 @@ | |||||
Subproject commit 1aa10c59b4e11564c2db76c2ba0039474d38df26 | |||||
Subproject commit 7cb171b9c511fec57ccc0ad746ef2126267fe18b |
@@ -1 +1 @@ | |||||
Subproject commit 7773435b776fb37231abcef2bbcf972814b01dd1 | |||||
Subproject commit 8d44bebfeeb71b793bc7325acc95345090789e19 |