@@ -38,10 +38,8 @@ const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
ProfilingManager::ProfilingManager() : is_load_profiling_(false), | |||||
is_execute_profiling_(false), | |||||
is_training_trace_(false), | |||||
subscribe_count_(0) { | |||||
ProfilingManager::ProfilingManager() | |||||
: is_load_profiling_(false), is_execute_profiling_(false), is_training_trace_(false), subscribe_count_(0) { | |||||
prof_cb_.msprofCtrlCallback = nullptr; | prof_cb_.msprofCtrlCallback = nullptr; | ||||
prof_cb_.msprofReporterCallback = nullptr; | prof_cb_.msprofReporterCallback = nullptr; | ||||
} | } | ||||
@@ -91,19 +89,18 @@ ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOpt | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
// enable profiling by env | // enable profiling by env | ||||
char env_profiling_mode[MMPA_MAX_PATH] = { 0x00 }; | char env_profiling_mode[MMPA_MAX_PATH] = { 0x00 }; | ||||
is_load_profiling_ = false; // Change in ProfInit | |||||
is_execute_profiling_ = false; | is_execute_profiling_ = false; | ||||
if (options.profiling_mode == "1" && !options.profiling_options.empty()) { | if (options.profiling_mode == "1" && !options.profiling_options.empty()) { | ||||
// enable profiling by ge option | // enable profiling by ge option | ||||
if (memcpy_s(prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX, options.profiling_options.c_str(), | |||||
options.profiling_options.size()) != EOK) { | |||||
if (strncpy_s(prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX, options.profiling_options.c_str(), | |||||
MSPROF_OPTIONS_DEF_LEN_MAX - 1) != EOK) { | |||||
GELOGE(INTERNAL_ERROR, "copy profiling_options failed."); | GELOGE(INTERNAL_ERROR, "copy profiling_options failed."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
is_execute_profiling_ = true; | is_execute_profiling_ = true; | ||||
GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), | |||||
prof_conf.options, options.profiling_options.c_str()); | |||||
GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), prof_conf.options, | |||||
options.profiling_options.c_str()); | |||||
} else { | } else { | ||||
(void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); | (void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); | ||||
(void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); | (void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); | ||||
@@ -127,11 +124,12 @@ ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOpt | |||||
return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
} | } | ||||
if (memcpy_s(prof_conf.jobId, sizeof(prof_conf.jobId), options.job_id.c_str(), | |||||
sizeof(options.job_id.c_str())) != EOK) { | |||||
if (strncpy_s(prof_conf.jobId, MSPROF_OPTIONS_DEF_LEN_MAX, options.job_id.c_str(), | |||||
MSPROF_OPTIONS_DEF_LEN_MAX - 1) != EOK) { | |||||
GELOGE(INTERNAL_ERROR, "copy job_id failed."); | GELOGE(INTERNAL_ERROR, "copy job_id failed."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
GELOGI("Job id: %s, original job id: %s.", prof_conf.jobId, options.job_id.c_str()); | |||||
#endif | #endif | ||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
@@ -143,6 +141,9 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { | |||||
} | } | ||||
try { | try { | ||||
Json prof_options = Json::parse(options); | Json prof_options = Json::parse(options); | ||||
if (options.find(kTrainingTrace) == std::string::npos) { | |||||
return ge::SUCCESS; | |||||
} | |||||
const std::string training_trace = prof_options[kTrainingTrace]; | const std::string training_trace = prof_options[kTrainingTrace]; | ||||
if (training_trace.empty()) { | if (training_trace.empty()) { | ||||
GELOGI("Training trace will not take effect."); | GELOGI("Training trace will not take effect."); | ||||
@@ -158,6 +159,7 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { | |||||
if (!fp_point_.empty() && !bp_point_.empty()) { | if (!fp_point_.empty() && !bp_point_.empty()) { | ||||
GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); | GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); | ||||
} | } | ||||
is_training_trace_ = true; | |||||
} catch (...) { | } catch (...) { | ||||
GELOGE(FAILED, "Json prof_conf options is invalid."); | GELOGE(FAILED, "Json prof_conf options is invalid."); | ||||
return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
@@ -627,6 +629,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt | |||||
uint64_t module, const std::map<std::string, std::string> &config_para) { | uint64_t module, const std::map<std::string, std::string> &config_para) { | ||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
std::lock_guard<std::mutex> lock(mutex_); | std::lock_guard<std::mutex> lock(mutex_); | ||||
uint64_t training_trace_mask = module & PROF_TRAINING_TRACE_MASK; | |||||
if (training_trace_mask == PROF_TRAINING_TRACE_MASK) { | |||||
is_training_trace_ = true; | |||||
} | |||||
int32_t device_num = 0; | int32_t device_num = 0; | ||||
vector<int32_t> device_list; | vector<int32_t> device_list; | ||||
if (ProfParseParam(config_para, device_num, device_list) != SUCCESS) { | if (ProfParseParam(config_para, device_num, device_list) != SUCCESS) { | ||||
@@ -402,6 +402,7 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
continuous_mem_start = iter->second.mem_offset_; | |||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
GE_IF_BOOL_EXEC(peer_out_data_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_data_anchor == nullptr, continue); | ||||
@@ -567,7 +567,7 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector<OpDescPtr> &ops, bool is_ | |||||
continue; | continue; | ||||
} | } | ||||
string op_type = op_desc->GetType(); | string op_type = op_desc->GetType(); | ||||
if (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0) { | |||||
if ((!is_single_stream && !op_desc->GetSubgraphInstanceNames().empty()) || separator_types.count(op_type) != 0) { | |||||
continuous_op_lists.emplace_back(vector<OpDescPtr>()); | continuous_op_lists.emplace_back(vector<OpDescPtr>()); | ||||
} else { | } else { | ||||
continuous_op_lists.back().emplace_back(op_desc); | continuous_op_lists.back().emplace_back(op_desc); | ||||
@@ -23,7 +23,10 @@ | |||||
namespace { | namespace { | ||||
const int kInvalidTransopDataIndex = -1; | const int kInvalidTransopDataIndex = -1; | ||||
const int kTransOpOutIndex = 0; | const int kTransOpOutIndex = 0; | ||||
std::map<ge::DataType, ge::DataType> precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; | |||||
std::map<ge::DataType, ge::DataType> precision_loss_transfer_map = { | |||||
{ge::DT_FLOAT, ge::DT_BOOL}, | |||||
{ge::DT_INT64, ge::DT_BOOL} | |||||
}; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
@@ -320,10 +320,10 @@ Status GraphLoader::GetMemoryInfo(int64_t &free) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphLoader::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { | |||||
Status GraphLoader::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id) { | |||||
auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
Status ret = model_manager->DestroyAicpuKernel(session_id, model_id); | |||||
Status ret = model_manager->DestroyAicpuKernel(session_id, model_id, sub_model_id); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Destroy aicpu kernel failed."); | GELOGE(ret, "Destroy aicpu kernel failed."); | ||||
return ret; | return ret; | ||||
@@ -68,7 +68,7 @@ class GraphLoader { | |||||
const std::vector<GeTensorDesc> &input_desc, OutputData &output_data, | const std::vector<GeTensorDesc> &input_desc, OutputData &output_data, | ||||
std::vector<GeTensorDesc> &output_desc); | std::vector<GeTensorDesc> &output_desc); | ||||
static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id); | |||||
static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id); | |||||
static Status DestroyAicpuSessionForInfer(uint32_t model_id); | static Status DestroyAicpuSessionForInfer(uint32_t model_id); | ||||
@@ -734,7 +734,6 @@ Status DavinciModel::ReportProfilingData() { | |||||
} | } | ||||
ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info); | ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info); | ||||
GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); | GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); | ||||
op_list_.clear(); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -136,6 +136,20 @@ class DavinciModel { | |||||
/// | /// | ||||
void SetId(uint32_t model_id) { model_id_ = model_id; } | void SetId(uint32_t model_id) { model_id_ = model_id; } | ||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Get SubModelId | |||||
/// @return sub model ID | |||||
/// | |||||
uint32_t SubModelId() const { return sub_model_id_; } | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set SubModelId | |||||
/// @return sub model ID | |||||
/// | |||||
void SetSubModelId(uint32_t sub_model_id) { sub_model_id_ = sub_model_id; } | |||||
static void *Run(DavinciModel *model_pointer); | static void *Run(DavinciModel *model_pointer); | ||||
/// | /// | ||||
@@ -815,6 +829,7 @@ class DavinciModel { | |||||
uint32_t model_id_; | uint32_t model_id_; | ||||
uint32_t runtime_model_id_; | uint32_t runtime_model_id_; | ||||
uint32_t sub_model_id_ = 0; | |||||
string name_; | string name_; | ||||
// used for inference data dump | // used for inference data dump | ||||
@@ -81,7 +81,8 @@ ModelManager::ModelManager() { | |||||
session_id_bias_ = 0; | session_id_bias_ = 0; | ||||
} | } | ||||
Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id) { | |||||
Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id, | |||||
uint32_t sub_model_id) { | |||||
STR_FWK_OP_KERNEL param_base = {}; | STR_FWK_OP_KERNEL param_base = {}; | ||||
void *devicebase = nullptr; | void *devicebase = nullptr; | ||||
void *aicpu_kernel_addr = nullptr; | void *aicpu_kernel_addr = nullptr; | ||||
@@ -91,10 +92,11 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u | |||||
param_base.fwkKernelBase.fwk_kernel.sessionID = session_id; | param_base.fwkKernelBase.fwk_kernel.sessionID = session_id; | ||||
if (op_type == aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY) { | if (op_type == aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY) { | ||||
std::vector<uint64_t> v_aicpu_kernel; | std::vector<uint64_t> v_aicpu_kernel; | ||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | |||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id) + "_" + | |||||
std::to_string(sub_model_id); | |||||
auto iter = model_aicpu_kernel_.find(model_key); | auto iter = model_aicpu_kernel_.find(model_key); | ||||
if (iter != model_aicpu_kernel_.end()) { | if (iter != model_aicpu_kernel_.end()) { | ||||
GELOGD("kernel destroy session_id %lu, model_id %u.", session_id, model_id); | |||||
GELOGD("kernel destroy session_id %lu, model_id %u, sub_model_id %u..", session_id, model_id, sub_model_id); | |||||
v_aicpu_kernel = model_aicpu_kernel_.at(model_key); | v_aicpu_kernel = model_aicpu_kernel_.at(model_key); | ||||
// Insert size of aicpu kernel vector in the first element | // Insert size of aicpu kernel vector in the first element | ||||
v_aicpu_kernel.insert(v_aicpu_kernel.begin(), v_aicpu_kernel.size()); | v_aicpu_kernel.insert(v_aicpu_kernel.begin(), v_aicpu_kernel.size()); | ||||
@@ -192,7 +194,7 @@ void ModelManager::DestroyAicpuSession(uint64_t session_id) { | |||||
GE_CHK_RT(rtSetDevice(static_cast<int32_t>(GetContext().DeviceId()))); | GE_CHK_RT(rtSetDevice(static_cast<int32_t>(GetContext().DeviceId()))); | ||||
} | } | ||||
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_SESSION_DESTROY, session_id, 0); | |||||
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_SESSION_DESTROY, session_id, 0, 0); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGW("The session: %lu destroy failed.", session_id); | GELOGW("The session: %lu destroy failed.", session_id); | ||||
} else { | } else { | ||||
@@ -218,20 +220,22 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { | |||||
auto it = model_map_.find(model_id); | auto it = model_map_.find(model_id); | ||||
if (it == model_map_.end()) { | if (it == model_map_.end()) { | ||||
GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id); | |||||
return GE_EXEC_MODEL_ID_INVALID; | |||||
GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id); | |||||
return ACL_ERROR_GE_EXEC_MODEL_ID_INVALID; | |||||
} | } | ||||
uint64_t session_id = it->second->GetSessionId(); | uint64_t session_id = it->second->GetSessionId(); | ||||
DestroyAicpuSession(session_id); | DestroyAicpuSession(session_id); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { | |||||
ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id) { | |||||
GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); | GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); | ||||
std::lock_guard<std::mutex> lock(map_mutex_); | std::lock_guard<std::mutex> lock(map_mutex_); | ||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | |||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id) + "_" + | |||||
std::to_string(sub_model_id); | |||||
if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | ||||
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); | |||||
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id, | |||||
sub_model_id); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(FAILED, "Destroy aicpu kernel failed."); | GELOGE(FAILED, "Destroy aicpu kernel failed."); | ||||
return FAILED; | return FAILED; | ||||
@@ -240,10 +244,12 @@ ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_ | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { | |||||
ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id, | |||||
uint64_t kernel_id) { | |||||
std::lock_guard<std::mutex> lock(map_mutex_); | std::lock_guard<std::mutex> lock(map_mutex_); | ||||
std::vector<uint64_t> v_aicpu_kernel; | std::vector<uint64_t> v_aicpu_kernel; | ||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | |||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id) + "_" + | |||||
std::to_string(sub_model_id); | |||||
if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | ||||
v_aicpu_kernel = model_aicpu_kernel_.at(model_key); | v_aicpu_kernel = model_aicpu_kernel_.at(model_key); | ||||
} | } | ||||
@@ -378,7 +384,8 @@ Status ModelManager::DeleteModel(uint32_t id) { | |||||
auto hybrid_model_it = hybrid_model_map_.find(id); | auto hybrid_model_it = hybrid_model_map_.find(id); | ||||
if (it != model_map_.end()) { | if (it != model_map_.end()) { | ||||
uint64_t session_id = it->second->GetSessionId(); | uint64_t session_id = it->second->GetSessionId(); | ||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(id); | |||||
std::string model_key = std::to_string(session_id) + "_" + std::to_string(id) + "_" + | |||||
std::to_string(it->second->SubModelId()); | |||||
auto iter_aicpu_kernel = model_aicpu_kernel_.find(model_key); | auto iter_aicpu_kernel = model_aicpu_kernel_.find(model_key); | ||||
if (iter_aicpu_kernel != model_aicpu_kernel_.end()) { | if (iter_aicpu_kernel != model_aicpu_kernel_.end()) { | ||||
(void)model_aicpu_kernel_.erase(iter_aicpu_kernel); | (void)model_aicpu_kernel_.erase(iter_aicpu_kernel); | ||||
@@ -905,7 +912,7 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector<Inpu | |||||
} | } | ||||
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, GE_EXEC_MODEL_ID_INVALID, | |||||
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
"GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); | "GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); | ||||
davinci_model->SetModelDescVersion(new_model_desc); | davinci_model->SetModelDescVersion(new_model_desc); | ||||
@@ -1224,7 +1231,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||||
// Zero copy is enabled by default, no need to judge. | // Zero copy is enabled by default, no need to judge. | ||||
uint64_t session_id_davinci = davinci_model->GetSessionId(); | uint64_t session_id_davinci = davinci_model->GetSessionId(); | ||||
uint32_t model_id_davinci = davinci_model->GetModelId(); | uint32_t model_id_davinci = davinci_model->GetModelId(); | ||||
Status status = DestroyAicpuKernel(session_id_davinci, model_id_davinci); | |||||
uint32_t sub_model_id = davinci_model->SubModelId(); | |||||
Status status = DestroyAicpuKernel(session_id_davinci, model_id_davinci, sub_model_id); | |||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGW("Destroy specified aicpu kernel failed, session id is %lu, model id is %u.", session_id_davinci, | GELOGW("Destroy specified aicpu kernel failed, session id is %lu, model id is %u.", session_id_davinci, | ||||
model_id_davinci); | model_id_davinci); | ||||
@@ -1244,7 +1252,7 @@ Status ModelManager::CreateAicpuSession(uint64_t session_id) { | |||||
auto it = sess_ids_.find(session_id); | auto it = sess_ids_.find(session_id); | ||||
// never been created by any model | // never been created by any model | ||||
if (it == sess_ids_.end()) { | if (it == sess_ids_.end()) { | ||||
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_SESSION_CREATE, session_id, 0); | |||||
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_SESSION_CREATE, session_id, 0, 0); | |||||
if (ret == SUCCESS) { | if (ret == SUCCESS) { | ||||
(void)sess_ids_.insert(session_id); | (void)sess_ids_.insert(session_id); | ||||
GELOGI("The session: %lu create success.", session_id); | GELOGI("The session: %lu create success.", session_id); | ||||
@@ -1558,6 +1566,12 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
size_t aicpu_op_nums = aicpu_optype_list.size(); | size_t aicpu_op_nums = aicpu_optype_list.size(); | ||||
size_t tf_op_nums = aicpu_tf_optype_list.size(); | size_t tf_op_nums = aicpu_tf_optype_list.size(); | ||||
size_t op_nums = aicpu_op_nums + tf_op_nums; | size_t op_nums = aicpu_op_nums + tf_op_nums; | ||||
std::function<void()> callback = [&]() { | |||||
for (auto mem : allocated_mem) { | |||||
GE_CHK_RT(rtFree(mem)); | |||||
} | |||||
}; | |||||
GE_MAKE_GUARD(release, callback); | |||||
// malloc sysOpInfoList in SysOpCheckInfo | // malloc sysOpInfoList in SysOpCheckInfo | ||||
status = rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM); | status = rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM); | ||||
if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
@@ -1637,8 +1651,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
return RT_ERROR_TO_GE_STATUS(status); | return RT_ERROR_TO_GE_STATUS(status); | ||||
} | } | ||||
allocated_mem.push_back(args); | allocated_mem.push_back(args); | ||||
GE_CHK_RT( | |||||
rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo), RT_MEMCPY_HOST_TO_DEVICE)); | |||||
GE_CHK_RT(rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo), | |||||
RT_MEMCPY_HOST_TO_DEVICE)); | |||||
GE_CHK_RT(rtMemcpy(reinterpret_cast<void *>(static_cast<uintptr_t>(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + op_check_info_req.offSetLen)), | GE_CHK_RT(rtMemcpy(reinterpret_cast<void *>(static_cast<uintptr_t>(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + op_check_info_req.offSetLen)), | ||||
sizeof(SysOpCheckResp), reinterpret_cast<void *>(&op_check_info_res), sizeof(SysOpCheckResp), RT_MEMCPY_HOST_TO_DEVICE)); | sizeof(SysOpCheckResp), reinterpret_cast<void *>(&op_check_info_res), sizeof(SysOpCheckResp), RT_MEMCPY_HOST_TO_DEVICE)); | ||||
GE_CHK_RT(rtStreamCreate(&stream, 0)); | GE_CHK_RT(rtStreamCreate(&stream, 0)); | ||||
@@ -1647,24 +1661,21 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
status = rtStreamSynchronize(stream); | status = rtStreamSynchronize(stream); | ||||
if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt stream sync failed, status: 0x%x", status); | GELOGE(RT_FAILED, "Call rt stream sync failed, status: 0x%x", status); | ||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
return RT_ERROR_TO_GE_STATUS(status); | return RT_ERROR_TO_GE_STATUS(status); | ||||
} | } | ||||
// Check the response | // Check the response | ||||
SysOpCheckResp *d_op_check_info_res = reinterpret_cast<SysOpCheckResp *>(reinterpret_cast<void *>(static_cast<uintptr_t>(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + op_check_info_req.offSetLen))); | |||||
SysOpCheckResp *d_op_check_info_res = | |||||
reinterpret_cast<SysOpCheckResp *>(reinterpret_cast<void *>(static_cast<uintptr_t>(static_cast<uint64_t>( | |||||
reinterpret_cast<uintptr_t>(args)) + op_check_info_req.offSetLen))); | |||||
(void)memset_s(&op_check_info_res, sizeof(SysOpCheckResp), 0, sizeof(SysOpCheckResp)); | (void)memset_s(&op_check_info_res, sizeof(SysOpCheckResp), 0, sizeof(SysOpCheckResp)); | ||||
GE_CHK_RT(rtMemcpy(&op_check_info_res, sizeof(SysOpCheckResp), d_op_check_info_res, sizeof(SysOpCheckResp), | GE_CHK_RT(rtMemcpy(&op_check_info_res, sizeof(SysOpCheckResp), d_op_check_info_res, sizeof(SysOpCheckResp), | ||||
RT_MEMCPY_DEVICE_TO_HOST)); | RT_MEMCPY_DEVICE_TO_HOST)); | ||||
std::function<void()> callback = [&]() { | |||||
for (auto mem : allocated_mem) { | |||||
GE_CHK_RT(rtFree(mem)); | |||||
} | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
}; | |||||
if (op_check_info_res.isWithoutJson) { | if (op_check_info_res.isWithoutJson) { | ||||
GELOGI("No need to check aicpu in this scenoria."); | GELOGI("No need to check aicpu in this scenoria."); | ||||
GE_MAKE_GUARD(release, callback); | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
uint64_t res_op_nums = op_check_info_res.opListNum; | uint64_t res_op_nums = op_check_info_res.opListNum; | ||||
@@ -1682,7 +1693,7 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
sizeof(SysOpInfo) * res_op_nums, RT_MEMCPY_DEVICE_TO_HOST)); | sizeof(SysOpInfo) * res_op_nums, RT_MEMCPY_DEVICE_TO_HOST)); | ||||
if (res_ret_code_list.size() != res_aicpu_op_info_list.size() || res_ret_code_list.size() != res_op_nums) { | if (res_ret_code_list.size() != res_aicpu_op_info_list.size() || res_ret_code_list.size() != res_op_nums) { | ||||
GELOGE(FAILED, "Number of retcode is not equal to number of op type."); | GELOGE(FAILED, "Number of retcode is not equal to number of op type."); | ||||
GE_MAKE_GUARD(release, callback); | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::string fail_reason; | std::string fail_reason; | ||||
@@ -1705,11 +1716,11 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op | |||||
} | } | ||||
fail_reason += "not support."; | fail_reason += "not support."; | ||||
GELOGE(FAILED, "Check aicpu op_type failed. details: %s", fail_reason.c_str()); | GELOGE(FAILED, "Check aicpu op_type failed. details: %s", fail_reason.c_str()); | ||||
GE_MAKE_GUARD(release, callback); | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GE_MAKE_GUARD(release, callback); | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
GELOGI("Cpu kernel launch check optype task success."); | GELOGI("Cpu kernel launch check optype task success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -273,7 +273,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
std::shared_ptr<hybrid::HybridDavinciModel> GetHybridModel(uint32_t id); | std::shared_ptr<hybrid::HybridDavinciModel> GetHybridModel(uint32_t id); | ||||
ge::Status KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id); | |||||
ge::Status KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id, | |||||
uint32_t sub_model_id); | |||||
ge::Status CreateAicpuSession(uint64_t session_id); | ge::Status CreateAicpuSession(uint64_t session_id); | ||||
@@ -281,9 +282,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
void DestroyAicpuSession(uint64_t session_id); | void DestroyAicpuSession(uint64_t session_id); | ||||
ge::Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id); | |||||
ge::Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id); | |||||
ge::Status CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id); | |||||
ge::Status CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id, uint64_t kernel_id); | |||||
ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); | ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); | ||||
@@ -97,14 +97,16 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
// 2.2 Collect aicpu kernel | // 2.2 Collect aicpu kernel | ||||
uint64_t kernel_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.kernelID; | uint64_t kernel_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.kernelID; | ||||
GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, | |||||
GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), | |||||
davinci_model->SubModelId(), kernel_id) != SUCCESS, | |||||
GELOGE(FAILED, "CreateAicpuKernel error."); | GELOGE(FAILED, "CreateAicpuKernel error."); | ||||
return FAILED;) | return FAILED;) | ||||
// 2.3 Create session | // 2.3 Create session | ||||
GE_CHECK_NOTNULL(ModelManager::GetInstance()); | GE_CHECK_NOTNULL(ModelManager::GetInstance()); | ||||
GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, | |||||
GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); | |||||
return FAILED;) | |||||
ret = ModelManager::GetInstance()->CreateAicpuSession(session_id); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(ret, "CreateAicpuSession error. session id: %lu", session_id); | |||||
return ret;) | |||||
kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | ||||
if (davinci_model_->IsKnownNode()) { | if (davinci_model_->IsKnownNode()) { | ||||
@@ -132,6 +134,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", rt_ret); | ||||
return RT_ERROR_TO_GE_STATUS(rt_ret);) | return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
InitDumpTask(input_output_addr, op_desc); | |||||
GELOGI("KernelExTaskInfo knonw node Init Success."); | GELOGI("KernelExTaskInfo knonw node Init Success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -166,11 +169,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); | ||||
return RT_ERROR_TO_GE_STATUS(rt_ret);) | return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), | |||||
op_desc->GetName())) { | |||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | |||||
dump_args_ = input_output_addr_; | |||||
} | |||||
InitDumpTask(input_output_addr_, op_desc); | |||||
if (davinci_model_->GetOpDugReg()) { | if (davinci_model_->GetOpDugReg()) { | ||||
GELOGI("Op debug is open in kernel ex task info"); | GELOGI("Op debug is open in kernel ex task info"); | ||||
dump_args_ = input_output_addr_; | dump_args_ = input_output_addr_; | ||||
@@ -200,6 +199,14 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void KernelExTaskInfo::InitDumpTask(void *addr, const OpDescPtr &op_desc) { | |||||
if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), | |||||
op_desc->GetName())) { | |||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | |||||
dump_args_ = input_output_addr_; | |||||
} | |||||
} | |||||
Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | ||||
auto kernel_ex_def = task_def.kernel_ex(); | auto kernel_ex_def = task_def.kernel_ex(); | ||||
uint32_t op_index = kernel_ex_def.op_index(); | uint32_t op_index = kernel_ex_def.op_index(); | ||||
@@ -60,6 +60,8 @@ class KernelExTaskInfo : public TaskInfo { | |||||
private: | private: | ||||
Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); | Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); | ||||
void InitDumpTask(void *addr, const OpDescPtr &op_desc); | |||||
uint32_t task_id_; | uint32_t task_id_; | ||||
uint32_t stream_id_; | uint32_t stream_id_; | ||||
uint32_t dump_flag_; | uint32_t dump_flag_; | ||||
@@ -571,6 +571,8 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
OpDescPtr op_desc = davinci_model_->GetOpByIndex(ctx_.opIndex); | OpDescPtr op_desc = davinci_model_->GetOpByIndex(ctx_.opIndex); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if (davinci_model_->IsKnownNode()) { | if (davinci_model_->IsKnownNode()) { | ||||
args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); | |||||
InitDumpTask(offset); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -635,15 +637,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
skt_dump_args_ = static_cast<char *>(args_) + offset; | skt_dump_args_ = static_cast<char *>(args_) + offset; | ||||
if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), | |||||
op_desc->GetName())) { | |||||
if (IsL1FusionOp(op_desc)) { | |||||
dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; | |||||
} else { | |||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | |||||
} | |||||
dump_args_ = static_cast<char *>(args_) + offset; | |||||
} | |||||
InitDumpTask(offset); | |||||
GE_CHK_BOOL_TRUE_EXEC_INFO(davinci_model_->GetOpDugReg(), dump_args_ = static_cast<char *>(args_) + offset, | GE_CHK_BOOL_TRUE_EXEC_INFO(davinci_model_->GetOpDugReg(), dump_args_ = static_cast<char *>(args_) + offset, | ||||
"Op debug is open in TVM task info"); | "Op debug is open in TVM task info"); | ||||
@@ -941,16 +935,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X", rt_ret); | ||||
return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
} | } | ||||
if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), | |||||
op_desc->GetName())) { | |||||
if (IsL1FusionOp(op_desc)) { | |||||
dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; | |||||
} else { | |||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | |||||
} | |||||
dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | |||||
} | |||||
InitDumpTask(sizeof(aicpu::AicpuParamHead)); | |||||
if (davinci_model_->GetOpDugReg()) { | if (davinci_model_->GetOpDugReg()) { | ||||
GELOGI("Op debug is open in aicpu task info"); | GELOGI("Op debug is open in aicpu task info"); | ||||
dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | ||||
@@ -964,6 +949,18 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void KernelTaskInfo::InitDumpTask(uint32_t offset) { | |||||
if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), | |||||
op_desc_->GetName())) { | |||||
if (IsL1FusionOp(op_desc_)) { | |||||
dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; | |||||
} else { | |||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | |||||
} | |||||
dump_args_ = static_cast<char *>(args_) + offset; | |||||
} | |||||
} | |||||
Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { | Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { | ||||
if (ext_info.empty()) { | if (ext_info.empty()) { | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -129,7 +129,9 @@ class KernelTaskInfo : public TaskInfo { | |||||
Status SuperKernelDistribute(); | Status SuperKernelDistribute(); | ||||
bool IsL1FusionOp(const OpDescPtr &op_desc); | bool IsL1FusionOp(const OpDescPtr &op_desc); | ||||
// For super kernel | |||||
void InitDumpTask(uint32_t offset); | |||||
// For super kernel | |||||
Status SaveSKTDumpInfo(); | Status SaveSKTDumpInfo(); | ||||
void UpdateTaskId(); | void UpdateTaskId(); | ||||
void UpdateSKTTaskId(); | void UpdateSKTTaskId(); | ||||
@@ -536,7 +536,7 @@ Status GraphManager::CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_gr | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, | |||||
Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, | |||||
Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id) { | Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id) { | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
// use default 16 multi thread | // use default 16 multi thread | ||||
@@ -737,6 +737,9 @@ Status GraphManager::PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, | |||||
GeRootModelPtr &ge_root_model, uint64_t session_id) { | GeRootModelPtr &ge_root_model, uint64_t session_id) { | ||||
GE_CHECK_NOTNULL(graph_node); | GE_CHECK_NOTNULL(graph_node); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
CompilerStages &stages = GetCompilerStages(graph_node->GetGraphId()); | |||||
GM_RUN_AND_DUMP_PERF("OptimizeWholeGraph", stages.optimizer.OptimizeWholeGraph, compute_graph); | |||||
GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); | GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); | ||||
GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", | GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", | ||||
GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, | GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, | ||||
@@ -2439,6 +2442,13 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||||
continue; | continue; | ||||
} | } | ||||
auto model_id = model->GetModelId(); | auto model_id = model->GetModelId(); | ||||
// unknown model not release | |||||
bool is_unknown_shape = false; | |||||
GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); | |||||
if (is_unknown_shape) { | |||||
GELOGD("model_id[%u] graph_id[%u] is unknown model, not release memory", model_id, graph_id); | |||||
continue; | |||||
} | |||||
// not loaded,no need unload | // not loaded,no need unload | ||||
if (!it.second->GetLoadFlag()) { | if (!it.second->GetLoadFlag()) { | ||||
GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); | GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); | ||||
@@ -2456,7 +2466,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id); | GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id); | ||||
continue; | continue; | ||||
} | } | ||||
result = GraphLoader::DestroyAicpuKernel(session_id, model_id); | |||||
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||||
if (result != SUCCESS) { | if (result != SUCCESS) { | ||||
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | ||||
graph_id); | graph_id); | ||||
@@ -336,4 +336,37 @@ Status GraphOptimize::IdentifyReference(ComputeGraphPtr &compute_graph) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphOptimize::OptimizeWholeGraph(ComputeGraphPtr &compute_graph) { | |||||
if (compute_graph == nullptr) { | |||||
GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeWholeGraph]: compute_graph is nullptr."); | |||||
return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | |||||
} | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeWholeGraph failed."); | |||||
return GE_CLI_GE_NOT_INITIALIZED; | |||||
} | |||||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||||
GELOGI("optimize by opskernel in OptimizeWholeGraph. num of graph_optimizer is %zu.", graph_optimizer.size()); | |||||
Status ret = SUCCESS; | |||||
string exclude_core_type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | |||||
GELOGD("[OptimizeWholeGraph]: engine type will exclude: %s", exclude_core_type.c_str()); | |||||
if (!graph_optimizer.empty()) { | |||||
for (auto &iter : graph_optimizer) { | |||||
if (iter.first == exclude_core_type || iter.second == nullptr) { | |||||
continue; | |||||
} | |||||
GELOGI("Begin to optimize whole graph by engine %s", iter.first.c_str()); | |||||
ret = iter.second->OptimizeWholeGraph(*compute_graph); | |||||
GE_DUMP(compute_graph, "OptimizeWholeGraph" + iter.first); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[OptimizeWholeGraph]: graph optimize failed, ret:%u", ret); | |||||
return ret; | |||||
} | |||||
} | |||||
} | |||||
return ret; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -52,6 +52,9 @@ class GraphOptimize { | |||||
// for fe prepare optimize in quantize scene | // for fe prepare optimize in quantize scene | ||||
Status OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph); | Status OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph); | ||||
// for engine to optimize merged whole graph before ge Optimize2 | |||||
Status OptimizeWholeGraph(ComputeGraphPtr &compute_graph); | |||||
// for rts optimize before build to add attr and insert memcpy op | // for rts optimize before build to add attr and insert memcpy op | ||||
Status OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph); | Status OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph); | ||||
@@ -22,6 +22,8 @@ | |||||
#include "graph/preprocess/multi_batch_options.h" | #include "graph/preprocess/multi_batch_options.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -478,8 +480,28 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { | |||||
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
(void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | ||||
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); | |||||
std::vector<std::string> input_dims_str; | |||||
for (size_t i = 0; i < batch_shapes_.size(); ++i) { | |||||
auto shape = data_shape; | |||||
auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str()); | |||||
return ret; | |||||
} | |||||
tensor.SetShape(shape); | |||||
int64_t tensor_size = 0; | |||||
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | |||||
string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | |||||
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + | |||||
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | |||||
formats::JoinToString(tensor.GetShape().GetDims()); | |||||
input_dims_str.emplace_back(input_str); | |||||
} | |||||
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | |||||
size_t max_shape_index = 0; | size_t max_shape_index = 0; | ||||
int64_t max_size = 0; | int64_t max_size = 0; | ||||
for (size_t i = 0; i < batch_shapes_.size(); ++i) { | for (size_t i = 0; i < batch_shapes_.size(); ++i) { | ||||
@@ -20,11 +20,12 @@ | |||||
#include "graph/passes/folding_pass.h" | #include "graph/passes/folding_pass.h" | ||||
namespace ge { | namespace ge { | ||||
constexpr uint32_t kDataOutIndex = 0; | |||||
constexpr uint32_t kZeroIndex = 0; | |||||
constexpr uint32_t kCaseInputBase = 1; | constexpr uint32_t kCaseInputBase = 1; | ||||
constexpr uint32_t kInvalidParent = 0x7fffffffU; | constexpr uint32_t kInvalidParent = 0x7fffffffU; | ||||
const string kMbatchNodeNameMark = "_ascend_mbatch_batch_"; | |||||
bool IsSameOpNode(const NodePtr &src_node, const NodePtr &dst_node) { | |||||
bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) { | |||||
if ((src_node == nullptr) && (dst_node == nullptr)) { | if ((src_node == nullptr) && (dst_node == nullptr)) { | ||||
return true; | return true; | ||||
} | } | ||||
@@ -37,35 +38,9 @@ bool IsSameOpNode(const NodePtr &src_node, const NodePtr &dst_node) { | |||||
return false; | return false; | ||||
} | } | ||||
if ((src_node->GetInControlNodes().size() != dst_node->GetInControlNodes().size()) || | |||||
(src_node->GetOutDataNodesSize() != dst_node->GetOutDataNodesSize())) { | |||||
return false; | |||||
} | |||||
set<uint32_t> related_parent; | |||||
const auto in_nodes = src_node->GetInControlNodes(); | |||||
for (uint32_t i = 0; i < in_nodes.size(); ++i) { | |||||
const auto owner_node = in_nodes.at(i); | |||||
uint32_t parent_index = 0; | |||||
if (!AttrUtils::GetInt(owner_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
return false; | |||||
} | |||||
related_parent.insert(parent_index); | |||||
} | |||||
for (const auto &in_node : dst_node->GetInControlNodes()) { | |||||
uint32_t parent_index = 0; | |||||
if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
return false; | |||||
} | |||||
if (related_parent.count(parent_index) == 0) { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
const GeTensorDesc &src_desc = src_node->GetOpDesc()->GetOutputDesc(kZeroIndex); | |||||
const GeTensorDesc &dst_desc = dst_node->GetOpDesc()->GetOutputDesc(kZeroIndex); | |||||
return (src_desc == dst_desc); | |||||
} | } | ||||
/*********************************************************************************************************************** | /*********************************************************************************************************************** | ||||
@@ -89,12 +64,12 @@ bool IsSameOpNode(const NodePtr &src_node, const NodePtr &dst_node) { | |||||
+-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ | +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ | ||||
| Data | | Data | | Data | | Data | | Data | | Data | | Conv2D | | | Data | | Data | | Data | | Data | | Data | | Data | | Conv2D | | ||||
+-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ | +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ | ||||
\ \ | / / | | | |||||
\ \ | / / | | | |||||
\ \ | / / | | | |||||
\ \ | / / | | | |||||
\ +-----------+ / | +-----------+ | |||||
+---------------| Const |----------------+ | | Pooling | | |||||
\ \ | / / | | +-----------+ | |||||
\ \ | / / | | | Const | | |||||
\ \ | / / | | +-----------+ | |||||
\ \ | / / | | / | |||||
\ +-----------+ / | +-----------+ / | |||||
+---------------| Const |----------------+ | | Pooling |-----+ | |||||
+-----------+ | +-----------+ | +-----------+ | +-----------+ | ||||
\ | / | \ | / | ||||
\ | / | \ | / | ||||
@@ -126,28 +101,26 @@ Status SubgraphConstMigrationPass::Run(ComputeGraphPtr graph) { | |||||
continue; | continue; | ||||
} | } | ||||
do { | |||||
migration_append_ = false; | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> graph_datas; | |||||
if (ClassifyDataNodes(graph, func_desc, graph_datas) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
map<ComputeGraphPtr, map<string, NodePtr>> all_const_nodes; | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> all_data_nodes; | |||||
if (ClassifyGraphNodes(graph, func_desc, all_const_nodes, all_data_nodes) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
if (graph_datas.empty()) { | |||||
GELOGW("Graph: %s subgraph is empty", graph->GetName().c_str()); | |||||
break; | |||||
} | |||||
if (all_const_nodes.empty()) { | |||||
GELOGW("Graph: %s subgraph is empty", graph->GetName().c_str()); | |||||
break; | |||||
} | |||||
// {subgraph0, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} | |||||
// {subgraph1, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} | |||||
// {subgraph2, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} | |||||
const auto base_nodes = graph_datas.begin()->second; // Need copy. | |||||
for (const auto &node_item : base_nodes) { | |||||
if (GraphNodeMigration(graph, node, graph_datas, node_item.second, node_item.first) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
// {subgraph0, {{key1, Const}, {key2, Const}, {key3, Const}, {key4, Const}, ..., {keyn, Const}}} | |||||
// {subgraph1, {{key1, Const}, {key2, Const}, {key3, Const}, {key4, Const}, ..., {keyn, Const}}} | |||||
// {subgraph2, {{key1, Const}, {key2, Const}, {key3, Const}, {key4, Const}, ..., {keyn, Const}}} | |||||
const auto &const_nodes = all_const_nodes.begin()->second; | |||||
for (const auto &item : const_nodes) { | |||||
if (GraphNodeMigration(graph, node, all_const_nodes, all_data_nodes, item.second, item.first) != SUCCESS) { | |||||
return FAILED; | |||||
} | } | ||||
} while (migration_append_); | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -155,14 +128,16 @@ Status SubgraphConstMigrationPass::Run(ComputeGraphPtr graph) { | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Get all Data nodes for all subgraph. | |||||
/// @brief Get all Const/Data nodes for all subgraph. | |||||
/// @param [in] graph: Root compute graph. | /// @param [in] graph: Root compute graph. | ||||
/// @param [in] func_desc: functional OpDesc of Case. | /// @param [in] func_desc: functional OpDesc of Case. | ||||
/// @param [out] graph_datas: Data groups of subgraph. | |||||
/// @param [out] all_const_nodes: Const groups of subgraph. | |||||
/// @param [out] all_data_nodes: Data groups of subgraph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SubgraphConstMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_datas) { | |||||
Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, | |||||
map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes) { | |||||
for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | ||||
const auto &subgraph = graph->GetSubgraph(name); | const auto &subgraph = graph->GetSubgraph(name); | ||||
if (subgraph == nullptr) { | if (subgraph == nullptr) { | ||||
@@ -170,32 +145,47 @@ Status SubgraphConstMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &grap | |||||
return GE_GRAPH_EMPTY_SUBGRAPH; | return GE_GRAPH_EMPTY_SUBGRAPH; | ||||
} | } | ||||
auto &data_nodes = graph_datas[subgraph]; | |||||
for (auto &data : subgraph->GetDirectNode()) { | |||||
if (data->GetType() != DATA) { | |||||
continue; | |||||
} | |||||
auto &data_nodes = all_data_nodes[subgraph]; | |||||
auto &const_nodes = all_const_nodes[subgraph]; | |||||
for (auto &node : subgraph->GetDirectNode()) { | |||||
if (node->GetType() == DATA) { | |||||
uint32_t parent_index = kInvalidParent; | |||||
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
return FAILED; | |||||
} | |||||
uint32_t parent_index = 0; | |||||
if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
GELOGE(FAILED, "Parent index not found, name: %s", data->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
data_nodes[parent_index] = node; | |||||
GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str()); | |||||
} else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) { | |||||
set<string> peer_name_list; | |||||
const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex); | |||||
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
const auto &peer_node = in_anchor->GetOwnerNode(); | |||||
// Trim subgraph node name prefix. | |||||
string node_full_name = peer_node->GetName(); | |||||
size_t pos = node_full_name.find(kMbatchNodeNameMark); | |||||
if (pos == string::npos) { | |||||
GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
string fixed_name = node_full_name.substr(0, pos); | |||||
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); | |||||
if (pos != string::npos) { | |||||
fixed_name += node_full_name.substr(pos); | |||||
} | |||||
peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); | |||||
} | |||||
data_nodes[parent_index] = data; | |||||
GELOGD("%s, Parent index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, data->GetName().c_str()); | |||||
} | |||||
} | |||||
string key_of_const; | |||||
for (const string &name : peer_name_list) { | |||||
key_of_const += (key_of_const.empty() ? name : "_" + name); | |||||
} | |||||
auto iter = graph_datas.begin(); | |||||
if (iter == graph_datas.end()) { | |||||
return SUCCESS; | |||||
} | |||||
for (const auto &data_nodes : graph_datas) { | |||||
if (data_nodes.second.size() != iter->second.size()) { | |||||
GELOGE(FAILED, "Subgraph %s has invalid Data nodes[%zu != %zu]", | |||||
data_nodes.first->GetName().c_str(), data_nodes.second.size(), iter->second.size()); | |||||
return FAILED; | |||||
const_nodes[key_of_const] = node; | |||||
GELOGD("%s, Key: %s, Const: %s", subgraph->GetName().c_str(), key_of_const.c_str(), node->GetName().c_str()); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -204,36 +194,27 @@ Status SubgraphConstMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &grap | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Get all Data nodes for all subgraph. | |||||
/// @param [in] node: Const node of subgraph. | |||||
/// @param [out] inputs: parent index to Const. | |||||
/// @param [out] outputs: Data groups of subgraph. | |||||
/// @brief Get parent_index for Const node migration. | |||||
/// @param [in] all_data_nodes: Data groups of subgraph. | |||||
/// @param [in] const_node: Const node will process. | |||||
/// @param [out] parent_index: parent index for replace Data. | |||||
/// @return true: SUCCESS / false: FAILED | /// @return true: SUCCESS / false: FAILED | ||||
/// | /// | ||||
bool SubgraphConstMigrationPass::GetAssociatedNodes(const NodePtr &node, map<uint32_t, uint32_t> &inputs, | |||||
map<uint32_t, uint32_t> &outputs) { | |||||
for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { | |||||
outputs[i] = kInvalidParent; | |||||
} | |||||
uint32_t out_index = 0; | |||||
const auto in_nodes = node->GetInAllNodes(); | |||||
for (size_t i = 0; i < in_nodes.size(); ++i) { | |||||
const auto owner_node = in_nodes.at(i); | |||||
if (owner_node->GetType() != DATA) { | |||||
bool SubgraphConstMigrationPass::GetAssociatedNodes(const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes, | |||||
const NodePtr &const_node, uint32_t &parent_index) { | |||||
for (const auto in_node : const_node->GetInAllNodes()) { | |||||
if (in_node->GetType() != DATA) { | |||||
return false; | return false; | ||||
} | } | ||||
uint32_t parent_index = 0; | |||||
if (!AttrUtils::GetInt(owner_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
uint32_t node_index = 0; | |||||
if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, node_index)) { | |||||
return false; | return false; | ||||
} | } | ||||
// Input Data feed other Node, need add new Data. | // Input Data feed other Node, need add new Data. | ||||
inputs[i] = parent_index; | |||||
if ((out_index == outputs.size()) && owner_node->GetOutDataNodes().empty()) { | |||||
outputs[out_index] = parent_index; | |||||
++out_index; | |||||
if ((parent_index == kInvalidParent) && in_node->GetOutDataNodes().empty()) { | |||||
parent_index = node_index; | |||||
} | } | ||||
} | } | ||||
@@ -242,43 +223,26 @@ bool SubgraphConstMigrationPass::GetAssociatedNodes(const NodePtr &node, map<uin | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Get all Data nodes for all subgraph. | |||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] data_base: Data Node for migration. | |||||
/// @param [in] data_idx: Data groups of subgraph. | |||||
/// @param [in] data_idx: Data groups of subgraph. | |||||
/// @brief Check parallel node is same for all subgraph. | |||||
/// @param [in] all_const_nodes: Const groups of subgraph. | |||||
/// @param [in] const_node: Const Node for migration. | |||||
/// @param [in] node_key: Key of Const node. | |||||
/// @return true: Same / false: not same | /// @return true: Same / false: not same | ||||
/// | /// | ||||
bool SubgraphConstMigrationPass::IsParallelNodeSame(const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_datas, | |||||
const NodePtr &const_node, uint32_t parent_index, size_t index) { | |||||
auto it = graph_datas.begin(); | |||||
for (++it; it != graph_datas.end(); ++it) { | |||||
const auto &data_nodes = it->second; | |||||
auto data_it = data_nodes.find(parent_index); | |||||
if (data_it == data_nodes.end()) { | |||||
GELOGE(FAILED, "Data: %s not fount, index: %u", const_node->GetName().c_str(), parent_index); | |||||
return false; | |||||
} | |||||
const auto &work_data = data_it->second; | |||||
const auto &out_anchor = work_data->GetOutControlAnchor(); | |||||
const auto &in_anchors = out_anchor->GetPeerInControlAnchors(); | |||||
if (in_anchors.size() <= index || in_anchors.at(index) == nullptr) { | |||||
GELOGW("Node anchors not same, Data: %s -> %s anchor size: %zu, index: %zu", | |||||
work_data->GetName().c_str(), const_node->GetName().c_str(), in_anchors.size(), index); | |||||
return false; | |||||
} | |||||
const auto &in_anchor = in_anchors.at(index); | |||||
const auto &work_node = in_anchor->GetOwnerNode(); | |||||
if (work_node == nullptr) { | |||||
GELOGE(FAILED, "Data: %s not found, parent: %u, index: %zu", const_node->GetName().c_str(), parent_index, index); | |||||
bool SubgraphConstMigrationPass::IsParallelNodeSame(const map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
const NodePtr &const_node, const string &node_key) { | |||||
auto it = all_const_nodes.begin(); | |||||
for (++it; it != all_const_nodes.end(); ++it) { | |||||
const auto &const_nodes = it->second; | |||||
auto node_it = const_nodes.find(node_key); | |||||
if (node_it == const_nodes.end()) { | |||||
GELOGW("Const node: %s not fount, key: %s", const_node->GetName().c_str(), node_key.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
if (!IsSameOpNode(const_node, work_node)) { | |||||
GELOGI("OpDesc not same: %s %s, parent: %u, index: %zu", | |||||
const_node->GetName().c_str(), work_node->GetName().c_str(), parent_index, index); | |||||
const auto &work_node = node_it->second; | |||||
if (!IsSameConstNode(const_node, work_node)) { | |||||
GELOGI("Not same: %s %s, key: %s", const_node->GetName().c_str(), work_node->GetName().c_str(), node_key.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -291,51 +255,34 @@ bool SubgraphConstMigrationPass::IsParallelNodeSame(const map<ComputeGraphPtr, m | |||||
/// @brief Migration subgraph Node to Root | /// @brief Migration subgraph Node to Root | ||||
/// @param [in] graph: Root compute graph. | /// @param [in] graph: Root compute graph. | ||||
/// @param [in] func_node: functional Node of Case. | /// @param [in] func_node: functional Node of Case. | ||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] data_base: Data Node for migration. | |||||
/// @param [in] data_idx: Data groups of subgraph. | |||||
/// @param [in] all_const_nodes: Const groups of subgraph. | |||||
/// @param [in] all_data_nodes: Data groups of subgraph. | |||||
/// @param [in] const_node: Const Node for migration. | |||||
/// @param [in] node_key: Key of Const node for migration. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SubgraphConstMigrationPass::GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, | Status SubgraphConstMigrationPass::GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, | ||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_datas, | |||||
const NodePtr &data_node, uint32_t parent_index) { | |||||
bool can_extrapolation = false; | |||||
do { | |||||
can_extrapolation = false; | |||||
const auto &out_anchor = data_node->GetOutControlAnchor(); | |||||
const auto &in_anchors = out_anchor->GetPeerInControlAnchors(); | |||||
for (size_t i = in_anchors.size(); i > 0; --i) { | |||||
const auto &in_anchor = in_anchors.at(i - 1); | |||||
const auto &work_node = in_anchor->GetOwnerNode(); | |||||
GELOGD("Data: %s, node: %s, parent: %u, index: %zu", | |||||
data_node->GetName().c_str(), work_node->GetName().c_str(), parent_index, i); | |||||
if (work_node->GetType() != CONSTANT) { | |||||
continue; | |||||
} | |||||
// Get associated Data, if Data feed other nodes, need append new Data. | |||||
map<uint32_t, uint32_t> inputs; | |||||
map<uint32_t, uint32_t> outputs; | |||||
if (!GetAssociatedNodes(work_node, inputs, outputs)) { | |||||
continue; | |||||
} | |||||
const map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes, | |||||
const NodePtr &const_node, const string &node_key) { | |||||
if (!IsParallelNodeSame(all_const_nodes, const_node, node_key)) { | |||||
return SUCCESS; | |||||
} | |||||
if (!IsParallelNodeSame(graph_datas, work_node, parent_index, i - 1)) { | |||||
continue; | |||||
} | |||||
// Get associated Data, if Data feed other nodes, need append new Data. | |||||
uint32_t parent_index = kInvalidParent; | |||||
if (!GetAssociatedNodes(all_data_nodes, const_node, parent_index)) { | |||||
return SUCCESS; | |||||
} | |||||
GELOGI("Move node: %s, parent: %u, index: %zu", work_node->GetName().c_str(), parent_index, i); | |||||
if (AppendParallelNode(graph_datas, func_node, outputs) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
GELOGI("Move node: %s, parent index: %u", const_node->GetName().c_str(), parent_index); | |||||
if (AppendParallelNode(func_node, parent_index, all_data_nodes) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
if (MoveNodeToParent(graph, func_node, graph_datas, parent_index, i - 1, inputs, outputs) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
can_extrapolation = true; | |||||
break; | |||||
} | |||||
} while (can_extrapolation); | |||||
if (MoveNodeToParent(graph, func_node, all_const_nodes, all_data_nodes, node_key, parent_index) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -343,114 +290,100 @@ Status SubgraphConstMigrationPass::GraphNodeMigration(const ComputeGraphPtr &gra | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Append Input Tensor for functional node. | /// @brief Append Input Tensor for functional node. | ||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] func_node: functional Node of Case. | /// @param [in] func_node: functional Node of Case. | ||||
/// @param [in] outputs: Parent index of Node output. | |||||
/// @param [in/out] parent_index: Parent index for migration. | |||||
/// @param [in/out] all_data_nodes: Data groups of subgraph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SubgraphConstMigrationPass::AppendParallelNode(map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_datas, | |||||
const NodePtr &func_node, map<uint32_t, uint32_t> &outputs) { | |||||
Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node, uint32_t &parent_index, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes) { | |||||
// If outputs index invalid, add Data and Input Tensor. | // If outputs index invalid, add Data and Input Tensor. | ||||
for (auto &item : outputs) { | |||||
if (item.second != kInvalidParent) { | |||||
continue; | |||||
} | |||||
// Add Data to subgraph. | |||||
map<ComputeGraphPtr, uint32_t> append_num; | |||||
for (auto &groups : graph_datas) { | |||||
const auto &subgraph = groups.first; | |||||
auto &data_nodes = groups.second; | |||||
item.second = func_node->GetAllInDataAnchorsSize() + append_num[subgraph]; // Update to valid parent index. | |||||
const auto data_name = subgraph->GetName() + "_data_" + std::to_string(item.second); | |||||
OpDescBuilder op_builder(data_name, DATA); | |||||
const OpDescPtr op_desc = op_builder.AddInput("x").AddOutput("y").Build(); | |||||
if (op_desc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Create multi-batch subgraph data desc failed"); | |||||
return OUT_OF_MEMORY; | |||||
} | |||||
if (parent_index != kInvalidParent) { | |||||
return SUCCESS; | |||||
} | |||||
uint32_t data_index = item.second - kCaseInputBase; | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index)) { | |||||
GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
// Add Data to subgraph. | |||||
parent_index = func_node->GetAllInDataAnchorsSize(); // Update to valid parent index. | |||||
for (auto &item : all_data_nodes) { | |||||
const auto &subgraph = item.first; | |||||
const auto data_name = subgraph->GetName() + "_data_" + std::to_string(parent_index); | |||||
OpDescBuilder op_builder(data_name, DATA); | |||||
const auto op_desc = op_builder.AddInput("x").AddOutput("y").Build(); | |||||
if (op_desc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Create multi-batch subgraph data desc failed"); | |||||
return OUT_OF_MEMORY; | |||||
} | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, item.second)) { | |||||
GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
uint32_t data_index = parent_index - kCaseInputBase; | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index)) { | |||||
GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
append_num[subgraph]++; | |||||
data_nodes[item.second] = subgraph->AddNode(op_desc); | |||||
GELOGI("Add Node: %s, parent index: %u", op_desc->GetName().c_str(), item.second); | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); | |||||
return FAILED; | |||||
} | } | ||||
// Add InputTensor to functional Node. | |||||
NodeUtils::AppendInputAnchor(func_node, item.second + 1); | |||||
item.second[parent_index] = subgraph->AddNode(op_desc); | |||||
GELOGI("Add Node: %s, parent index: %u", op_desc->GetName().c_str(), parent_index); | |||||
} | } | ||||
// Add InputTensor to functional Node. | |||||
NodeUtils::AppendInputAnchor(func_node, parent_index + 1); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Delete Node from all subgraph. | |||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] detach: Node will move to parent. | |||||
/// @param [in] outputs: Parent index of Node output. | |||||
/// @brief Delete Node from subgraph. | |||||
/// @param [in] graph: subgraph for process. | |||||
/// @param [in] const_node: Node will move to parent. | |||||
/// @param [in] data_node: Place holder for Const. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SubgraphConstMigrationPass::DetachParallelNode(const map<uint32_t, NodePtr> &graph_datas, const NodePtr &detach, | |||||
const map<uint32_t, uint32_t> &outputs) { | |||||
Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph, const NodePtr &const_node, | |||||
const NodePtr &data_node) { | |||||
// Break Data and Move node. | // Break Data and Move node. | ||||
const auto &in_anchor = detach->GetInControlAnchor(); | |||||
const auto &out_anchors = in_anchor->GetPeerOutControlAnchors(); | |||||
for (size_t i = out_anchors.size(); i > 0; --i) { | |||||
const auto &out_anchor = out_anchors.at(i - 1); | |||||
const auto &in_anchor = const_node->GetInControlAnchor(); | |||||
const auto out_anchors = in_anchor->GetPeerOutControlAnchors(); | |||||
for (const auto out_anchor : out_anchors) { | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); | GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); | ||||
const auto &owner_node = out_anchor->GetOwnerNode(); | |||||
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), detach->GetName().c_str()); | |||||
} | |||||
// Break Move and follow, Link Data and follow. | |||||
for (uint32_t i = 0; i < detach->GetAllOutDataAnchorsSize(); ++i) { | |||||
auto it_idx = outputs.find(i); | |||||
if (it_idx == outputs.end()) { | |||||
GELOGE(FAILED, "Node: %s parent index %u not found", detach->GetName().c_str(), i); | |||||
return FAILED; | |||||
} | |||||
auto it_data = graph_datas.find(it_idx->second); | |||||
if (it_data == graph_datas.end()) { | |||||
GELOGE(FAILED, "Node: %s parent index %u not found", detach->GetName().c_str(), i); | |||||
return FAILED; | |||||
const auto owner_node = out_anchor->GetOwnerNode(); | |||||
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str()); | |||||
if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty() && owner_node != data_node) { | |||||
graph->RemoveNode(owner_node); | |||||
} | } | ||||
} | |||||
const auto &data_node = it_data->second; | |||||
const auto &out_anchor = detach->GetOutDataAnchor(i); | |||||
const auto &ctrl_anchor = const_node->GetOutControlAnchor(); | |||||
const auto ctrl_anchors = ctrl_anchor->GetPeerInControlAnchors(); | |||||
for (const auto in_anchor : ctrl_anchors) { | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(ctrl_anchor, in_anchor), "Remove edge failed"); | |||||
GELOGI("Remove Edge: %s %s", const_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
const auto &out_desc = detach->GetOpDesc()->GetOutputDesc(i); | |||||
const auto &data_desc = data_node->GetOpDesc(); | |||||
(void)data_desc->UpdateInputDesc(kDataOutIndex, out_desc); // Set Data Input to new connect Node. | |||||
(void)data_desc->UpdateOutputDesc(kDataOutIndex, out_desc); // Set Data Output to new connect Node. | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_node->GetOutControlAnchor(), in_anchor), "Add edge failed"); | |||||
GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
} | |||||
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
if (in_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); | |||||
const auto &owner_node = in_anchor->GetOwnerNode(); | |||||
GELOGI("Remove Edge: %s %s", detach->GetName().c_str(), owner_node->GetName().c_str()); | |||||
// Break Move and follow, Link Data and follow. | |||||
const auto &out_anchor = const_node->GetOutDataAnchor(kZeroIndex); | |||||
const auto in_anchors =out_anchor->GetPeerInDataAnchors(); | |||||
for (const auto in_anchor : in_anchors) { | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); | |||||
GELOGI("Remove Edge: %s %s", const_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
const auto &data_out_anchor = data_node->GetOutDataAnchor(kDataOutIndex); | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_out_anchor, in_anchor), "Add edge failed"); | |||||
GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), owner_node->GetName().c_str()); | |||||
} | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_node->GetOutDataAnchor(kZeroIndex), in_anchor), "Add edge failed"); | |||||
GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
} | } | ||||
// Update Data op DataType. | |||||
const auto &const_desc = const_node->GetOpDesc(); | |||||
const auto &tensor_desc = const_desc->GetOutputDesc(kZeroIndex); | |||||
const auto &data_desc = data_node->GetOpDesc(); | |||||
(void)data_desc->UpdateInputDesc(kZeroIndex, tensor_desc); // Set Data Input to new connect Node. | |||||
(void)data_desc->UpdateOutputDesc(kZeroIndex, tensor_desc); // Set Data Output to new connect Node. | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -459,47 +392,37 @@ Status SubgraphConstMigrationPass::DetachParallelNode(const map<uint32_t, NodePt | |||||
/// @brief Move Node to Parent Graph. | /// @brief Move Node to Parent Graph. | ||||
/// @param [in] graph: Parent compute graph. | /// @param [in] graph: Parent compute graph. | ||||
/// @param [in] func_node: functional Node of Case. | /// @param [in] func_node: functional Node of Case. | ||||
/// @param [in] attach: Node will move to parent. | |||||
/// @param [in] inputs: Parent index of Node input. | |||||
/// @param [in] outputs: Parent index of Node output. | |||||
/// @param [in] const_node: Node will move to parent. | |||||
/// @param [in] parent_index: Parent index of Node input. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SubgraphConstMigrationPass::AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, | Status SubgraphConstMigrationPass::AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, | ||||
const NodePtr &attach, const map<uint32_t, uint32_t> &inputs, | |||||
const map<uint32_t, uint32_t> &outputs) { | |||||
GE_CHECK_NOTNULL(attach); | |||||
for (const auto item : inputs) { | |||||
if (item.second == kInvalidParent) { // Not connect, Skip. | |||||
continue; | |||||
} | |||||
const auto &in_anchor = func_node->GetInDataAnchor(item.second); | |||||
const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
const auto &owner_node = out_anchor->GetOwnerNode(); | |||||
const auto &in_control = attach->GetInControlAnchor(); | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(owner_node->GetOutControlAnchor(), in_control), "Add edge failed"); | |||||
GELOGI("Add Edge: %s %s", owner_node->GetName().c_str(), attach->GetName().c_str()); | |||||
const NodePtr &const_node, uint32_t parent_index) { | |||||
GE_CHECK_NOTNULL(const_node); | |||||
if (parent_index == kInvalidParent) { | |||||
return INTERNAL_ERROR; | |||||
} | } | ||||
for (const auto &item : outputs) { | |||||
const auto &func_desc = func_node->GetOpDesc(); | |||||
const auto &out_desc = attach->GetOpDesc()->GetOutputDesc(item.second); | |||||
(void)func_desc->UpdateInputDesc(item.second, out_desc); // Set Data Input to new connect Node. | |||||
const auto &in_anchor = func_node->GetInDataAnchor(item.second); | |||||
const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (out_anchor != nullptr) { | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); | |||||
const auto &owner_node = out_anchor->GetOwnerNode(); | |||||
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), func_node->GetName().c_str()); | |||||
const auto &func_desc = func_node->GetOpDesc(); | |||||
const auto &tensor_desc = const_node->GetOpDesc()->GetOutputDesc(kZeroIndex); | |||||
(void)func_desc->UpdateInputDesc(parent_index, tensor_desc); // Set Data Input to new connect Node. | |||||
const auto &in_anchor = func_node->GetInDataAnchor(parent_index); | |||||
const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (out_anchor != nullptr) { // Break useless old link. | |||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); | |||||
const auto owner_node = out_anchor->GetOwnerNode(); | |||||
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), func_node->GetName().c_str()); | |||||
if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) { | |||||
graph->RemoveNode(owner_node); | |||||
} | } | ||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(attach->GetOutDataAnchor(item.first), in_anchor), "Add edge failed"); | |||||
GELOGI("Add Edge: %s %s", attach->GetName().c_str(), func_node->GetName().c_str()); | |||||
} | } | ||||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(const_node->GetOutDataAnchor(kZeroIndex), in_anchor), "Add edge failed"); | |||||
GELOGI("Add Edge: %s %s, index: %u", const_node->GetName().c_str(), func_node->GetName().c_str(), parent_index); | |||||
(void)graph->AddNode(attach); | |||||
(void)attach->SetOwnerComputeGraph(graph); | |||||
GELOGI("Add Node: %s %s", graph->GetName().c_str(), attach->GetName().c_str()); | |||||
(void)graph->AddNode(const_node); | |||||
(void)const_node->SetOwnerComputeGraph(graph); | |||||
GELOGI("Add Node: %s %s", graph->GetName().c_str(), const_node->GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -515,43 +438,37 @@ Status SubgraphConstMigrationPass::AttachParallelNode(const ComputeGraphPtr &gra | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, | Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, | ||||
const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_datas, | |||||
uint32_t parent_index, uint32_t index, | |||||
const map<uint32_t, uint32_t> &inputs, | |||||
const map<uint32_t, uint32_t> &outputs) { | |||||
if (inputs.empty()) { | |||||
const map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes, | |||||
const string &node_key, uint32_t parent_index) { | |||||
if (node_key.empty() || parent_index == kInvalidParent) { | |||||
GELOGE(FAILED, "Graph: %s, inputs is empty", graph->GetName().c_str()); | GELOGE(FAILED, "Graph: %s, inputs is empty", graph->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
NodePtr move_node; | NodePtr move_node; | ||||
for (auto &groups : graph_datas) { | |||||
const auto &subgraph = groups.first; | |||||
const auto &data_nodes = groups.second; | |||||
auto it = data_nodes.find(parent_index); | |||||
if (it == data_nodes.end()) { | |||||
GELOGE(FAILED, "Graph: %s, Data: %u node not found", subgraph->GetName().c_str(), parent_index); | |||||
for (auto &item : all_const_nodes) { | |||||
const auto &subgraph = item.first; | |||||
const auto it_const = item.second.find(node_key); | |||||
if (it_const == item.second.end()) { | |||||
GELOGE(FAILED, "Graph: %s, Const: %s node not found", subgraph->GetName().c_str(), node_key.c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
move_node = it_const->second; | |||||
const auto &base_data = it->second; | |||||
const auto &out_anchor = base_data->GetOutControlAnchor(); | |||||
const auto &in_anchors = out_anchor->GetPeerInControlAnchors(); | |||||
if (in_anchors.size() <= index || in_anchors.at(index) == nullptr) { | |||||
GELOGE(FAILED, "Data: %s, anchor size: %zu, index: %u not found", | |||||
base_data->GetName().c_str(), in_anchors.size(), index); | |||||
const auto it_nodes = all_data_nodes.find(subgraph); | |||||
if (it_nodes == all_data_nodes.end()) { | |||||
GELOGE(FAILED, "Graph: %s, Const: %s node not found", subgraph->GetName().c_str(), node_key.c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
const auto &in_anchor = in_anchors.at(index); | |||||
move_node = in_anchor->GetOwnerNode(); | |||||
if (move_node == nullptr) { | |||||
GELOGE(FAILED, "Data: %s not found, index: %u", base_data->GetName().c_str(), parent_index); | |||||
const auto it_data = it_nodes->second.find(parent_index); | |||||
if (it_data == it_nodes->second.end()) { | |||||
GELOGE(FAILED, "Graph: %s, Const: %s node not found", subgraph->GetName().c_str(), node_key.c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (DetachParallelNode(data_nodes, move_node, outputs) != SUCCESS) { | |||||
GELOGE(FAILED, "Data: %s not found, index: %u", base_data->GetName().c_str(), parent_index); | |||||
if (DetachParallelNode(subgraph, move_node, it_data->second) != SUCCESS) { | |||||
GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -559,11 +476,10 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph | |||||
GELOGI("Remove Node: %s %s", subgraph->GetName().c_str(), move_node->GetName().c_str()); | GELOGI("Remove Node: %s %s", subgraph->GetName().c_str(), move_node->GetName().c_str()); | ||||
} | } | ||||
if (AttachParallelNode(graph, func_node, move_node, inputs, outputs) != SUCCESS) { | |||||
if (AttachParallelNode(graph, func_node, move_node, parent_index) != SUCCESS) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
migration_append_ = true; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -36,50 +36,54 @@ class SubgraphConstMigrationPass : public GraphPass { | |||||
private: | private: | ||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Get all Data nodes for all subgraph. | |||||
/// @brief Get all Const/Data nodes for all subgraph. | |||||
/// @param [in] graph: Root compute graph. | /// @param [in] graph: Root compute graph. | ||||
/// @param [in] func_desc: functional OpDesc of Case. | /// @param [in] func_desc: functional OpDesc of Case. | ||||
/// @param [out] graph_datas: Data groups of subgraph. | |||||
/// @param [out] all_const_nodes: Const groups of subgraph. | |||||
/// @param [out] all_data_nodes: Data groups of subgraph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_datas); | |||||
Status ClassifyGraphNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, | |||||
map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Get all Data nodes for all subgraph. | |||||
/// @param [in] node: Const node of subgraph. | |||||
/// @param [in] func_desc: functional OpDesc of Case. | |||||
/// @param [out] graph_nodes: Data groups of subgraph. | |||||
/// @brief Get parent_index for Const node migration. | |||||
/// @param [in] all_data_nodes: Data groups of subgraph. | |||||
/// @param [in] const_node: Const node will process. | |||||
/// @param [out] parent_index: parent index for replace Data. | |||||
/// @return true: SUCCESS / false: FAILED | /// @return true: SUCCESS / false: FAILED | ||||
/// | /// | ||||
bool GetAssociatedNodes(const NodePtr &node, map<uint32_t, uint32_t> &inputs, map<uint32_t, uint32_t> &outputs); | |||||
bool GetAssociatedNodes(const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes, | |||||
const NodePtr &const_node, uint32_t &parent_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Get all Data nodes for all subgraph. | |||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] data_base: Data Node for migration. | |||||
/// @param [in] data_idx: Data groups of subgraph. | |||||
/// @param [in] data_idx: Data groups of subgraph. | |||||
/// @brief Check parallel node is same for all subgraph. | |||||
/// @param [in] all_const_nodes: Const groups of subgraph. | |||||
/// @param [in] const_node: Const Node for migration. | |||||
/// @param [in] node_key: Key of Const node. | |||||
/// @return true: Same / false: not same | /// @return true: Same / false: not same | ||||
/// | /// | ||||
bool IsParallelNodeSame(const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_nodes, | |||||
const NodePtr &const_node, uint32_t parent_index, size_t index); | |||||
bool IsParallelNodeSame(const map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
const NodePtr &const_node, const string &node_key); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Migration subgraph Node to Root | /// @brief Migration subgraph Node to Root | ||||
/// @param [in] graph: Root compute graph. | /// @param [in] graph: Root compute graph. | ||||
/// @param [in] func_node: functional Node of Case. | /// @param [in] func_node: functional Node of Case. | ||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] data_base: Data Node for migration. | |||||
/// @param [in] data_idx: Data groups of subgraph. | |||||
/// @param [in] all_const_nodes: Const groups of subgraph. | |||||
/// @param [in] all_data_nodes: Data groups of subgraph. | |||||
/// @param [in] const_node: Const Node for migration. | |||||
/// @param [in] node_key: Key of Const node for migration. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, | Status GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, | ||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_nodes, | |||||
const NodePtr &data_base, uint32_t data_idx); | |||||
const map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes, | |||||
const NodePtr &const_node, const string &node_key); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -93,46 +97,42 @@ class SubgraphConstMigrationPass : public GraphPass { | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, | Status MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, | ||||
const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_nodes, | |||||
uint32_t parent_index, uint32_t anchor_idx, | |||||
const map<uint32_t, uint32_t> &inputs, const map<uint32_t, uint32_t> &outputs); | |||||
const map<ComputeGraphPtr, map<string, NodePtr>> &all_const_nodes, | |||||
const map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes, | |||||
const string &node_key, uint32_t parent_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Append Input Tensor for functional node. | /// @brief Append Input Tensor for functional node. | ||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] func_node: functional Node of Case. | |||||
/// @param [in] outputs: Parent index of Node output. | |||||
/// @param [in] graph_nodes: Const groups of subgraph. | |||||
/// @param [in/out] parent_index: Parent index for migration. | |||||
/// @param [in/out] all_data_nodes: Data groups of subgraph. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status AppendParallelNode(map<ComputeGraphPtr, map<uint32_t, NodePtr>> &graph_nodes, | |||||
const NodePtr &func_node, map<uint32_t, uint32_t> &outputs); | |||||
Status AppendParallelNode(const NodePtr &func_node, uint32_t &parent_index, | |||||
map<ComputeGraphPtr, map<uint32_t, NodePtr>> &all_data_nodes); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Delete Node from all subgraph. | |||||
/// @param [in] graph_nodes: Data groups of subgraph. | |||||
/// @param [in] detach: Node will move to parent. | |||||
/// @param [in] outputs: Parent index of Node output. | |||||
/// @brief Delete Node from subgraph. | |||||
/// @param [in] graph: subgraph for process. | |||||
/// @param [in] const_node: Node will move to parent. | |||||
/// @param [in] data_node: Place holder for Const. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status DetachParallelNode(const map<uint32_t, NodePtr> &graph_datas, const NodePtr &detach, | |||||
const map<uint32_t, uint32_t> &outputs); | |||||
Status DetachParallelNode(const ComputeGraphPtr &graph, const NodePtr &const_node, const NodePtr &data_node); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Move Node to Parent Graph. | /// @brief Move Node to Parent Graph. | ||||
/// @param [in] graph: Parent compute graph. | /// @param [in] graph: Parent compute graph. | ||||
/// @param [in] func_node: functional Node of Case. | /// @param [in] func_node: functional Node of Case. | ||||
/// @param [in] attach: Node will move to parent. | |||||
/// @param [in] inputs: Parent index of Node input. | |||||
/// @param [in] outputs: Parent index of Node output. | |||||
/// @param [in] const_node: Node will move to parent. | |||||
/// @param [in] parent_index: Parent index of Node input. | |||||
/// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
/// | /// | ||||
Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, const NodePtr &attach, | |||||
const map<uint32_t, uint32_t> &inputs, const map<uint32_t, uint32_t> &outputs); | |||||
bool migration_append_{false}; | |||||
Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, | |||||
const NodePtr &const_node, uint32_t parent_index); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ | #endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ |
@@ -1646,6 +1646,10 @@ Status GraphPrepare::InferShapeForPreprocess() { | |||||
if (!options_.train_graph_flag) { | if (!options_.train_graph_flag) { | ||||
names_to_passes.emplace_back("AssertPass", &assert_pass); | names_to_passes.emplace_back("AssertPass", &assert_pass); | ||||
} | } | ||||
SwitchDeadBranchElimination switch_dead_branch_elimination; | |||||
names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | |||||
MergePass merge_pass; | |||||
names_to_passes.emplace_back("MergePass", &merge_pass); | |||||
InferShapePass infer_shape_pass; | InferShapePass infer_shape_pass; | ||||
names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); | names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); | ||||
ReplaceWithEmptyConstPass replace_with_empty_const_pass; | ReplaceWithEmptyConstPass replace_with_empty_const_pass; | ||||
@@ -123,11 +123,22 @@ Status KnownNodeTask::Init(TaskContext &context) { | |||||
davinci_model_->GetRuntimeParam().mem_base, davinci_model_->GetRuntimeParam().mem_size); | davinci_model_->GetRuntimeParam().mem_base, davinci_model_->GetRuntimeParam().mem_size); | ||||
} | } | ||||
if (!load_flag_) { | if (!load_flag_) { | ||||
auto dump_properties = context.GetDumpProperties(); | |||||
if (dump_properties.IsDumpOpen()) { | |||||
davinci_model_->SetDumpProperties(dump_properties); | |||||
} | |||||
int32_t device_id = 0; | |||||
rtError_t rt_ret = rtGetDevice(&device_id); | |||||
if (rt_ret != RT_ERROR_NONE || device_id < 0) { | |||||
GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); | |||||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
} | |||||
davinci_model_->SetDeviceId(device_id); | |||||
GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed."); | GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed."); | ||||
load_flag_ = true; | load_flag_ = true; | ||||
} else { | } else { | ||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), | GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), | ||||
davinci_model_->Id()), "KnownNodeTask::Init destroy aicpu kernel failed."); | |||||
davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed."); | |||||
} | } | ||||
GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName()); | GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName()); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -161,8 +172,9 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node | |||||
// set known node flag as true | // set known node flag as true | ||||
davinci_model->SetKnownNode(true); | davinci_model->SetKnownNode(true); | ||||
davinci_model->SetId(model.GetModelId()); | |||||
// set model id as root node's node id | // set model id as root node's node id | ||||
davinci_model->SetId(node->GetOpDesc()->GetId()); | |||||
davinci_model->SetSubModelId(node->GetOpDesc()->GetId()); | |||||
GELOGD("KnownNodeExecutor::LoadTask node id %ld.", node->GetOpDesc()->GetId()); | GELOGD("KnownNodeExecutor::LoadTask node id %ld.", node->GetOpDesc()->GetId()); | ||||
GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); | GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); | ||||
@@ -581,42 +581,6 @@ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *pat | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
auto root_graph = compute_graph->GetParentGraph(); | |||||
if (root_graph != nullptr) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Input param should not be subgraph"); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
auto ret = compute_graph->TopologicalSorting(); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(ret, "Acl topo logical sort failed."); | |||||
return ret; | |||||
} | |||||
ret = compute_graph->InferOriginFormat(); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(ret, "Acl InferOriginFormat failed."); | |||||
return ret; | |||||
} | |||||
for (auto &node: compute_graph->GetAllNodes()) { | |||||
graphStatus ret = ShapeRefiner::InferShapeAndType(node); | |||||
if (ret == GRAPH_PARAM_INVALID) { | |||||
GELOGW("Can not find infershape func."); | |||||
continue; | |||||
} else if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(ret, "Acl infershape failed."); | |||||
return ret; | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len) { | graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len) { | ||||
GE_CHECK_NOTNULL(file); | GE_CHECK_NOTNULL(file); | ||||
@@ -78,8 +78,8 @@ void CsaInteract::Init(int32_t dev_index, int64_t job_id) { | |||||
Status CsaInteract::WriteJobState(JobState job_state, JobSubState job_sub_state, uint32_t module_ret_errcode, | Status CsaInteract::WriteJobState(JobState job_state, JobSubState job_sub_state, uint32_t module_ret_errcode, | ||||
ErrorModule error_module) { | ErrorModule error_module) { | ||||
if (!is_init_) { | if (!is_init_) { | ||||
GELOGE(INTERNAL_ERROR, "CsaInteract has not init, can't WriteJobState"); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "CsaInteract has not init, can't WriteJobState"); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
} | } | ||||
if ((curr_state_ == JOBSTATE_FAILED) || (curr_state_ == JOBSTATE_KILLED)) { | if ((curr_state_ == JOBSTATE_FAILED) || (curr_state_ == JOBSTATE_KILLED)) { | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -26,7 +26,7 @@ extern "C" { | |||||
#define ACL_PROF_ACL_API 0x0001 | #define ACL_PROF_ACL_API 0x0001 | ||||
#define ACL_PROF_TASK_TIME 0x0002 | #define ACL_PROF_TASK_TIME 0x0002 | ||||
#define ACL_PROF_AICORE_METRICS 0x0004 | #define ACL_PROF_AICORE_METRICS 0x0004 | ||||
#define ACL_PROF_AICPU_TRACE 0x0008 | |||||
#define ACL_PROF_AICPU 0x0008 | |||||
#define ACL_PROF_MAX_OP_NAME_LEN 257 | #define ACL_PROF_MAX_OP_NAME_LEN 257 | ||||
#define ACL_PROF_MAX_OP_TYPE_LEN 65 | #define ACL_PROF_MAX_OP_TYPE_LEN 65 | ||||
@@ -289,34 +289,8 @@ ACL_FUNC_VISIBILITY uint64_t aclprofGetOpDuration(const void *opInfo, size_t opI | |||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY size_t aclprofGetModelId(const void *opInfo, size_t opInfoLen, uint32_t index); | ACL_FUNC_VISIBILITY size_t aclprofGetModelId(const void *opInfo, size_t opInfoLen, uint32_t index); | ||||
/** | |||||
* @ingroup AscendCL | |||||
* @brief get cube ops from subscription data | |||||
* | |||||
* @param opInfo [IN] pointer to subscription data | |||||
* @param opInfoLen [IN] memory size of subscription data | |||||
* @param index [IN] index of op array in opInfo | |||||
* | |||||
* @retval cube ops of subscription data | |||||
* @retval 0 for failed | |||||
*/ | |||||
ACL_FUNC_VISIBILITY uint64_t aclprofGetOpCubeOps(const void *opInfo, size_t opInfoLen, uint32_t index); | |||||
/** | |||||
* @ingroup AscendCL | |||||
* @brief get vector ops from subscription data | |||||
* | |||||
* @param opInfo [IN] pointer to subscription data | |||||
* @param opInfoLen [IN] memory size of subscription data | |||||
* @param index [IN] index of op array in opInfo | |||||
* | |||||
* @retval vector ops of subscription data | |||||
* @retval 0 for failed | |||||
*/ | |||||
ACL_FUNC_VISIBILITY uint64_t aclprofGetOpVectorOps(const void *opInfo, size_t opInfoLen, uint32_t index); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_PROF_H_ | |||||
#endif // INC_EXTERNAL_ACL_PROF_H_ |
@@ -13,6 +13,7 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef INC_EXTERNAL_ACL_OPS_ACL_RETR_H_ | #ifndef INC_EXTERNAL_ACL_OPS_ACL_RETR_H_ | ||||
#define INC_EXTERNAL_ACL_OPS_ACL_RETR_H_ | #define INC_EXTERNAL_ACL_OPS_ACL_RETR_H_ | ||||
@@ -394,7 +394,7 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | |||||
// for interface: aclgrphParse | // for interface: aclgrphParse | ||||
const std::set<std::string> ir_parser_suppported_options = { | const std::set<std::string> ir_parser_suppported_options = { | ||||
INPUT_FP16_NODES, IS_INPUT_ADJUST_HW_LAYOUT, IS_OUTPUT_ADJUST_HW_LAYOUT, OUTPUT, | INPUT_FP16_NODES, IS_INPUT_ADJUST_HW_LAYOUT, IS_OUTPUT_ADJUST_HW_LAYOUT, OUTPUT, | ||||
OUT_NODES, COMPRESS_WEIGHT_CONF, ENABLE_SCOPE_FUSION_PASSES}; | |||||
OUT_NODES, ENABLE_SCOPE_FUSION_PASSES}; | |||||
// for interface: aclgrphBuildInitialize | // for interface: aclgrphBuildInitialize | ||||
const std::set<std::string> global_options = {CORE_TYPE, | const std::set<std::string> global_options = {CORE_TYPE, | ||||
@@ -102,16 +102,6 @@ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *pat | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief infer shape and data type | |||||
* | |||||
* @param graph[IN] the graph ready to build | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||||
* @retval OtherValues Failure | |||||
*/ | |||||
graphStatus aclgrphInferShapeAndType(ge::Graph &graph); | |||||
/** | |||||
* @ingroup AscendCL | |||||
* @brief dump graph | * @brief dump graph | ||||
* | * | ||||
* @param graph[IN] the graph ready to build | * @param graph[IN] the graph ready to build | ||||
@@ -1,11 +1,17 @@ | |||||
/** | /** | ||||
* @file rt_error_codes.h | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved. | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | * | ||||
* This program is distributed in the hope that it will be useful, | |||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | */ | ||||
#ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ | #ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ | ||||
@@ -28,7 +28,7 @@ | |||||
#include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
#if !defined(__ANDROID__) && !defined(ANDROID) | #if !defined(__ANDROID__) && !defined(ANDROID) | ||||
#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||||
#define DOMI_LOGE(fmt, ...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, fmt, ##__VA_ARGS__) | |||||
#else | #else | ||||
#include <android/log.h> | #include <android/log.h> | ||||
#if defined(BUILD_VERSION_PERF) | #if defined(BUILD_VERSION_PERF) | ||||
@@ -20,7 +20,7 @@ | |||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||||
#define GE_LOGE(fmt, ...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, fmt, ##__VA_ARGS__) | |||||
#define GE_LOGI_IF(condition, ...) \ | #define GE_LOGI_IF(condition, ...) \ | ||||
if ((condition)) { \ | if ((condition)) { \ | ||||
@@ -108,4 +108,5 @@ message DumpData{ | |||||
repeated OpOutput output = 3; | repeated OpOutput output = 3; | ||||
repeated OpInput input = 4; | repeated OpInput input = 4; | ||||
repeated OpBuffer buffer = 5; | repeated OpBuffer buffer = 5; | ||||
string op_name = 6; | |||||
} | } |
@@ -766,6 +766,7 @@ graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, bool before_sub | |||||
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | ||||
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | ||||
} | } | ||||
GE_CHK_STATUS_RET_NOLOG(NodeUtils::UpdatePeerNodeInputDesc(node)); | |||||
} else { | } else { | ||||
GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); | GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -318,10 +318,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||||
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | ||||
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | ||||
if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||||
auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc(); | |||||
if (peer_anchor_opdesc == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | ||||
continue; | continue; | ||||
} | } | ||||
if (op_desc->GetId() < peer_anchor_opdesc->GetId() || | |||||
peer_anchor_opdesc->GetType() == CONSTANT || | |||||
peer_anchor_opdesc->GetType() == CONSTANTOP) { | |||||
GELOGD("no need to UpdatePeerNodeInputDesc"); | |||||
continue; | |||||
} | |||||
auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); | auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); | ||||
if (peer_input_desc == nullptr) { | if (peer_input_desc == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | ||||
@@ -337,22 +344,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||||
peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), | peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), | ||||
TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); | TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); | ||||
} else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { | } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { | ||||
string out_shape_str, peer_in_shape_str; | |||||
out_shape_str += "["; | |||||
for (int64_t dim : out_dims) { | |||||
out_shape_str += std::to_string(dim) + " "; | |||||
} | |||||
out_shape_str += "]"; | |||||
peer_in_shape_str += "["; | |||||
for (int64_t dim : peer_input_dims) { | |||||
peer_in_shape_str += std::to_string(dim) + " "; | |||||
} | |||||
peer_in_shape_str += "]"; | |||||
GELOGW("current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " | GELOGW("current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " | ||||
"input_shape is [%s].The two shape should be same! Please check graph and fix it", | "input_shape is [%s].The two shape should be same! Please check graph and fix it", | ||||
node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), | |||||
peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); | |||||
node_ptr->GetName().c_str(), out_anchor->GetIdx(), output_tensor->GetShape().ToString().c_str(), | |||||
peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), | |||||
peer_input_desc->GetShape().ToString().c_str()); | |||||
} | } | ||||
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | ||||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), | peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), | ||||
@@ -108,4 +108,5 @@ message DumpData{ | |||||
repeated OpOutput output = 3; | repeated OpOutput output = 3; | ||||
repeated OpInput input = 4; | repeated OpInput input = 4; | ||||
repeated OpBuffer buffer = 5; | repeated OpBuffer buffer = 5; | ||||
string op_name = 6; | |||||
} | } |
@@ -108,4 +108,5 @@ message DumpData{ | |||||
repeated OpOutput output = 3; | repeated OpOutput output = 3; | ||||
repeated OpInput input = 4; | repeated OpInput input = 4; | ||||
repeated OpBuffer buffer = 5; | repeated OpBuffer buffer = 5; | ||||
string op_name; | |||||
} | } |
@@ -108,4 +108,5 @@ message DumpData{ | |||||
repeated OpOutput output = 3; | repeated OpOutput output = 3; | ||||
repeated OpInput input = 4; | repeated OpInput input = 4; | ||||
repeated OpBuffer buffer = 5; | repeated OpBuffer buffer = 5; | ||||
string op_name = 6; | |||||
} | } |
@@ -13,6 +13,7 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef AICPU_ENGINE_H__ | #ifndef AICPU_ENGINE_H__ | ||||
#define AICPU_ENGINE_H__ | #define AICPU_ENGINE_H__ | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -365,6 +365,25 @@ REG_OP(BiasAddGrad) | |||||
* 4-D with shape [batch, out_height, out_width, out_channels] | * 4-D with shape [batch, out_height, out_width, out_channels] | ||||
* or [batch, out_channels, out_height, out_width]. | * or [batch, out_channels, out_height, out_width]. | ||||
* Gradients with respect to the output of the convolution. | * Gradients with respect to the output of the convolution. | ||||
*\n | |||||
*\n | |||||
* The following are the supported data types and data formats: | |||||
*@verbatim | |||||
| Tensor | out_bckprop | filter | y | |||||
------------|-------------|---------|-------- | |||||
| Data Type | float16 | float16 | float16 | |||||
| |-------------|---------|-------- | |||||
| | float32 | float32 | float32 | |||||
| |-------------|---------|-------- | |||||
| | float64 | float64 | float64 | |||||
------------|-------------|---------|-------- | |||||
| Format | NCHW | NCHW | NCHW | |||||
| | NHWC | HWCN | NHWC | |||||
@endverbatim | |||||
* For float32 and float64 type, the actual calculation on the chip is based on | |||||
* float16. | |||||
*\n | |||||
* | |||||
*@par Attributes: | *@par Attributes: | ||||
* Five attributes: | * Five attributes: | ||||
* @li strides: A tuple/list of 4 integers. The stride of the sliding window | * @li strides: A tuple/list of 4 integers. The stride of the sliding window | ||||
@@ -377,8 +396,52 @@ REG_OP(BiasAddGrad) | |||||
* channels. | * channels. | ||||
* @li data_format: An optional string from: "NHWC", "NCHW". Defaults to | * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to | ||||
* "NHWC". Specify the data format of the input and output data. | * "NHWC". Specify the data format of the input and output data. | ||||
*\n | |||||
*\n | |||||
* The following value range restrictions must be met: | |||||
*@verbatim | |||||
| Name | Field | Scope | |||||
-------------------|----------|-------------- | |||||
| input_size | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Filter | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| out_backprop | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| y(fmap) | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
-------------------|----------|-------------- | |||||
| Padding | Top | [0, 255] | |||||
| | Bottom | [0, 255] | |||||
| | Left | [0, 255] | |||||
| | Right | [0, 255] | |||||
-------------------|----------|-------------- | |||||
| Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
@endverbatim | |||||
* In Ascend910, fmap or out_backprop's H and W not support 1 when | |||||
* fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 | |||||
*\n | |||||
* | |||||
*@par Outputs: | *@par Outputs: | ||||
* y: A Tensor. Has the same type as filter,and has same format as input_size. | * y: A Tensor. Has the same type as filter,and has same format as input_size. | ||||
*\n | |||||
* out_backprop_height = (fmap_height + pad_top + pad_bottom - | |||||
* (dilation_h * (filter_height - 1) + 1)) | |||||
* / stride_h + 1 | |||||
*\n | |||||
* out_backprop_width = (fmap_width + pad_left + pad_right - | |||||
* (dilation_w * (filter_width - 1) + 1)) | |||||
* / stride_w + 1 | |||||
*\n | |||||
* | |||||
*@par Third-party framework compatibility | *@par Third-party framework compatibility | ||||
* Compatible with Tensorflow's conv2d_backprop_input | * Compatible with Tensorflow's conv2d_backprop_input | ||||
*/ | */ | ||||
@@ -454,6 +517,21 @@ REG_OP(Conv2DBackpropInputD) | |||||
* @li bias: An optional tensor. Must have the same type as "y". | * @li bias: An optional tensor. Must have the same type as "y". | ||||
* @li offset_w: An optional 1D tensor for quantized deconvolution. | * @li offset_w: An optional 1D tensor for quantized deconvolution. | ||||
* Type is int8. Reserved.\n | * Type is int8. Reserved.\n | ||||
*\n | |||||
*\n | |||||
* The following are the supported data types and data formats: | |||||
*@verbatim | |||||
| Tensor | x | filter | bias | y | |||||
------------|---------|---------|---------|-------- | |||||
| Data Type | float16 | float16 | float16 | float16 | |||||
| |---------|---------|---------|-------- | |||||
| | int8 | int8 | int32 | int32 | |||||
------------|---------|---------|---------|-------- | |||||
| Format | NCHW | NCHW | ND | NCHW | |||||
@endverbatim | |||||
* For int8, a dequant or requant operator must be followed. | |||||
*\n | |||||
* | |||||
*@par Attributes: | *@par Attributes: | ||||
* Six attributes: | * Six attributes: | ||||
* @li strides: A tuple or list of 2 integers. The stride of the sliding window | * @li strides: A tuple or list of 2 integers. The stride of the sliding window | ||||
@@ -468,8 +546,51 @@ REG_OP(Conv2DBackpropInputD) | |||||
Specify the data format of the input and output data. | Specify the data format of the input and output data. | ||||
* @li offset_x: An optional integer for quantized deconvolution. | * @li offset_x: An optional integer for quantized deconvolution. | ||||
* Defaults to "0". | * Defaults to "0". | ||||
*\n | |||||
*\n | |||||
* The following value range restrictions must be met: | |||||
*@verbatim | |||||
| Name | Field | Scope | |||||
-------------------|----------|-------------- | |||||
| x (out_backprop) | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Filter | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| y (fmap) | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
-------------------|----------|-------------- | |||||
| Padding | Top | [0, 255] | |||||
| | Bottom | [0, 255] | |||||
| | Left | [0, 255] | |||||
| | Right | [0, 255] | |||||
-------------------|----------|-------------- | |||||
| Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| Offset_x | | [-128, 127] | |||||
@endverbatim | |||||
* In Ascend910, fmap or out_backprop's H and W not support 1 when | |||||
* fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 | |||||
*\n | |||||
* | |||||
*@par Outputs: | *@par Outputs: | ||||
* y: A Tensor. 4D tensor with shape [batch, channels, height, width]. | * y: A Tensor. 4D tensor with shape [batch, channels, height, width]. | ||||
*\n | |||||
* out_backprop_height = (fmap_height + pad_top + pad_bottom - | |||||
* (dilation_h * (filter_height - 1) + 1)) | |||||
* / stride_h + 1 | |||||
*\n | |||||
* out_backprop_width = (fmap_width + pad_left + pad_right - | |||||
* (dilation_w * (filter_width - 1) + 1)) | |||||
* / stride_w + 1 | |||||
*\n | |||||
* | |||||
* When type of x is float16, the type of y must be float16. | * When type of x is float16, the type of y must be float16. | ||||
* When type of x is int8, the type of y must be int32. | * When type of x is int8, the type of y must be int32. | ||||
*/ | */ | ||||
@@ -502,6 +623,25 @@ REG_OP(Deconvolution) | |||||
* [batch, out_height, out_width, out_channels] or [batch, out_channels, | * [batch, out_height, out_width, out_channels] or [batch, out_channels, | ||||
* out_height, out_width]. Gradients with respect to the output of the | * out_height, out_width]. Gradients with respect to the output of the | ||||
* convolution. | * convolution. | ||||
*\n | |||||
*\n | |||||
* The following are the supported data types and data formats: | |||||
*@verbatim | |||||
| Tensor | x | out_backprop | y | |||||
------------|---------|--------------|--------- | |||||
| Data Type | float16 | float16 | float16 | |||||
| |---------|--------------|--------- | |||||
| | float32 | float32 | float32 | |||||
| |---------|--------------|--------- | |||||
| | float64 | float64 | float64 | |||||
|-----------|---------|--------------|--------- | |||||
| Format | NCHW | NCHW | NCHW | |||||
| | NHWC | NHWC | HWCN | |||||
@endverbatim | |||||
* For float32 and float64 type of x and outbackprop, the actual calculation on the chip | |||||
* is based on float16. | |||||
*\n | |||||
* | |||||
*@par Attributes: | *@par Attributes: | ||||
* Five attributes: | * Five attributes: | ||||
* @li strides: A tuple/list of 4 integers. The stride of the sliding window | * @li strides: A tuple/list of 4 integers. The stride of the sliding window | ||||
@@ -514,8 +654,52 @@ REG_OP(Deconvolution) | |||||
* channels. | * channels. | ||||
* @li data_format: An optional string from: "NHWC", "NCHW". Defaults to | * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to | ||||
* "NHWC". Specify the data format of the input and output data. | * "NHWC". Specify the data format of the input and output data. | ||||
*\n | |||||
*\n | |||||
* The following value range restrictions must be met: | |||||
*@verbatim | |||||
| Name | Field | Scope | |||||
-------------------|----------|-------------- | |||||
| x(fmap) | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Filter Size | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| out_backprop | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| y | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
-------------------|----------|-------------- | |||||
| Padding | Top | [0, 255] | |||||
| | Bottom | [0, 255] | |||||
| | Left | [0, 255] | |||||
| | Right | [0, 255] | |||||
-------------------|----------|-------------- | |||||
| Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
@endverbatim | |||||
* In Ascend910, out_backprop's H and W not support 1 when | |||||
* fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 | |||||
*\n | |||||
* | |||||
*@par Outputs: | *@par Outputs: | ||||
* y: A Tensor. Has the same type as x, has the same format as filter_size. | * y: A Tensor. Has the same type as x, has the same format as filter_size. | ||||
*\n | |||||
* out_backprop_height = (in_height + pad_top + pad_bottom - | |||||
* (dilation_h * (filter_height - 1) + 1)) | |||||
* / stride_h + 1 | |||||
*\n | |||||
* out_backprop_width = (in_width + pad_left + pad_right - | |||||
* (dilation_w * (filter_width - 1) + 1)) | |||||
* / stride_w + 1 | |||||
*\n | |||||
* | |||||
*@par Third-party framework compatibility | *@par Third-party framework compatibility | ||||
* Compatible with Tensorflow's conv2d_backprop_filter | * Compatible with Tensorflow's conv2d_backprop_filter | ||||
*/ | */ | ||||
@@ -1031,9 +1215,7 @@ REG_OP(Conv3DBackpropInputD) | |||||
* @li c_t: A optinal Tensor dtype of float16, float32. The cell state at time t . \n | * @li c_t: A optinal Tensor dtype of float16, float32. The cell state at time t . \n | ||||
*@par Third-party framework compatibility: | *@par Third-party framework compatibility: | ||||
* Compatible with the Pytorch operator adds. | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
* Compatible with the Caffe operator LSTM. | |||||
*/ | */ | ||||
REG_OP(LSTM) | REG_OP(LSTM) | ||||
.INPUT(x, TensorType({DT_FLOAT16})) | .INPUT(x, TensorType({DT_FLOAT16})) | ||||
@@ -1275,6 +1457,22 @@ REG_OP(Conv3DTransposeD) | |||||
* or [out_channels, in_channel, filter_height, filter_width]. | * or [out_channels, in_channel, filter_height, filter_width]. | ||||
* @li bias: An optional 1D tensor of type float16 or int32. Format is "ND". | * @li bias: An optional 1D tensor of type float16 or int32. Format is "ND". | ||||
* @li offset_w: An optional 1D tensor for quantized inference. Reserved. | * @li offset_w: An optional 1D tensor for quantized inference. Reserved. | ||||
*\n | |||||
*\n | |||||
* The following are the supported data types and data formats: | |||||
*@verbatim | |||||
| Tensor | x | filter | bias | y | |||||
------------|---------|---------|---------|-------- | |||||
| Data Type | float16 | float16 | float16 | float16 | |||||
| |---------|---------|---------|-------- | |||||
| | int8 | int8 | int32 | int32 | |||||
------------|---------|---------|---------|-------- | |||||
| Format | NCHW | NCHW | ND | NCHW | |||||
| | NHWC | HWCN | | NHWC | |||||
@endverbatim | |||||
* For int8, a dequant or requant operator must be followed. | |||||
*\n | |||||
* | |||||
*@par Required Attributes: | *@par Required Attributes: | ||||
* @li strides: A required tuple/list of 4 integers. The stride of the sliding | * @li strides: A required tuple/list of 4 integers. The stride of the sliding | ||||
* window for H/W dimension. The index of H/W is same as data_format. | * window for H/W dimension. The index of H/W is same as data_format. | ||||
@@ -1293,9 +1491,55 @@ REG_OP(Conv3DTransposeD) | |||||
* to [0, 0, 0, 0]. | * to [0, 0, 0, 0]. | ||||
* @li offset_x: An optional int. Input offset, used for quantized inference. | * @li offset_x: An optional int. Input offset, used for quantized inference. | ||||
* Defaults to "0". | * Defaults to "0". | ||||
*\n | |||||
*\n | |||||
* The following value range restrictions must be met: | |||||
*@verbatim | |||||
| Name | Field | Scope | |||||
-------------------|----------|-------------- | |||||
| input_size | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| x (out_backprop) | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| filter | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| y (fmap) | H | [1, 4096] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
-------------------|----------|-------------- | |||||
| Padding | Top | [0, 255] | |||||
| | Bottom | [0, 255] | |||||
| | Left | [0, 255] | |||||
| | Right | [0, 255] | |||||
-------------------|----------|-------------- | |||||
| Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| Offset_x | | [-128, 127] | |||||
@endverbatim | |||||
* In Ascend910, fmap or out_backprop's H and W not support 1 when | |||||
* fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 | |||||
*\n | |||||
* | |||||
*@par Outputs: | *@par Outputs: | ||||
* y: A Tensor. A Tensor of type float16 or int32, and has same format as | * y: A Tensor. A Tensor of type float16 or int32, and has same format as | ||||
* input_size. | * input_size. | ||||
*\n | |||||
* out_backprop_height = (fmap_height + pad_top + pad_bottom - | |||||
* (dilation_h * (filter_height - 1) + 1)) | |||||
* / stride_h + 1 | |||||
*\n | |||||
* out_backprop_width = (fmap_width + pad_left + pad_right - | |||||
* (dilation_w * (filter_width - 1) + 1)) | |||||
* / stride_w + 1 | |||||
*\n | |||||
* | |||||
*/ | */ | ||||
REG_OP(Conv2DTranspose) | REG_OP(Conv2DTranspose) | ||||
.INPUT(input_size, TensorType({DT_INT32, DT_INT64})) | .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1487,6 +1487,9 @@ REG_OP(DecodeBboxV2) | |||||
*@par Outputs: | *@par Outputs: | ||||
* @li y1: A Tensor. Must have the same type as x. | * @li y1: A Tensor. Must have the same type as x. | ||||
* @li y2: A Tensor. Indices of y1 in x.Dtype must be int32. | * @li y2: A Tensor. Indices of y1 in x.Dtype must be int32. | ||||
* | |||||
*@attention Constraints: | |||||
* The upper limit of data on the direction axis is 7040. | |||||
*/ | */ | ||||
REG_OP(Sort) | REG_OP(Sort) | ||||
.INPUT(x, TensorType({ DT_FLOAT16 })) | .INPUT(x, TensorType({ DT_FLOAT16 })) | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -223,6 +223,83 @@ REG_OP(AvgPool3DD) | |||||
.ATTR(data_format, String, "NDHWC") | .ATTR(data_format, String, "NDHWC") | ||||
.OP_END_FACTORY_REG(AvgPool3DD) | .OP_END_FACTORY_REG(AvgPool3DD) | ||||
/** | |||||
* @brief Computes AvgPool3DGrad function. | |||||
* @par Inputs: | |||||
* @li orig_input_shape: An NDHWC tensor of type float16, float32, or double. | |||||
* @li grads: An NDHWC tensor of type int32. | |||||
* @par Attributes: | |||||
* @li ksize: List of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor. | |||||
* @li strides:List of ints that has length 1, 3 or 5. The stride of the sliding window for each dimension of the input tensor. | |||||
* @li pads: List of ints, implicit zero paddings on both sides of the input. | |||||
* @li ceil_mode: When true, will use ceil instead of floor in the formula to compute the output shape. | |||||
* @li count_include_pad: When true, will include the zero-padding in the averaging calculation. | |||||
* @li divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. | |||||
* @li data_format: A string, format of input data. | |||||
* @par Outputs: | |||||
* @output: A mutable tensor with the same shape and type as "orig_input". | |||||
* @par Third-party framework compatibility | |||||
* @li Compatible with the TensorFlow operator AvgPoolGrad. | |||||
*/ | |||||
REG_OP(AvgPool3DGrad) | |||||
.INPUT(orig_input_shape, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) | |||||
.INPUT(grads, TensorType({DT_INT32})) | |||||
.OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) | |||||
.REQUIRED_ATTR(ksize, ListInt) | |||||
.REQUIRED_ATTR(strides, ListInt) | |||||
.REQUIRED_ATTR(pads, ListInt) | |||||
.ATTR(ceil_mode, Bool, false) | |||||
.ATTR(count_include_pad, Bool, true) | |||||
.ATTR(divisor_override, Int, 0) | |||||
.ATTR(data_format, String, "NDHWC") | |||||
.OP_END_FACTORY_REG(AvgPool3DGrad) | |||||
/** | |||||
* @brief Performs average pooling on the input. | |||||
* @par Inputs: | |||||
* @li grads: An NDHWC tensor of type float16. | |||||
* @li filter: An optional tensor of type float16, fractal_z_3d layout. | |||||
* @li multiplier: An optional tensor of float16. | |||||
* @par Attributes: | |||||
* @li orig_input_shape: List of ints that has length 5. The size of the window for each dimension of the input tensor. | |||||
* @li ksize: List of ints that has length 3. The size of the window for each dimension of the input tensor. | |||||
* @li strides:List of ints that has length 3. The stride of the sliding window for each dimension of the input tensor. | |||||
* @li pads: List of ints, implicit zero paddings on both sides of the input. | |||||
* @li ceil_mode: When true, will use ceil instead of floor in the formula to compute the output shape. | |||||
* @li count_include_pad: When true, will include the zero-padding in the averaging calculation. | |||||
* @li divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. | |||||
* @li data_format: A string, format of input data . \n | |||||
* @par Outputs: | |||||
* @output: The average pooled output tensor . \n | |||||
* @attention Constraints: | |||||
* @li "ksize" is in the range [1, 255]. "strides" is in the range [1, 63] | |||||
* @par Third-party framework compatibility | |||||
* Compatible with the TensorFlow operator AvgPool3DGradD. | |||||
*/ | |||||
REG_OP(AvgPool3DGradD) | |||||
.INPUT(grads, TensorType({DT_FLOAT16})) | |||||
.OPTIONAL_INPUT(filter, TensorType({DT_FLOAT16})) | |||||
.OPTIONAL_INPUT(multiplier, TensorType({DT_FLOAT16})) | |||||
.OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) | |||||
.REQUIRED_ATTR(orig_input_shape, ListInt) | |||||
.REQUIRED_ATTR(ksize, ListInt) | |||||
.REQUIRED_ATTR(strides, ListInt) | |||||
.REQUIRED_ATTR(pads, ListInt) | |||||
.ATTR(ceil_mode, Bool, false) | |||||
.ATTR(count_include_pad, Bool, true) | |||||
.ATTR(divisor_override, Int, 0) | |||||
.ATTR(data_format, String, "NDHWC") | |||||
.OP_END_FACTORY_REG(AvgPool3DGradD) | |||||
/** | /** | ||||
*@brief Performs max_pool_ext2 on the input . \n | *@brief Performs max_pool_ext2 on the input . \n | ||||
@@ -350,6 +427,31 @@ REG_OP(MaxPool3D) | |||||
.ATTR(data_format, String, "NDHWC") | .ATTR(data_format, String, "NDHWC") | ||||
.OP_END_FACTORY_REG(MaxPool3D) | .OP_END_FACTORY_REG(MaxPool3D) | ||||
/** | |||||
*@brief Applies a 2D adaptive max pooling over an input signal conposed of several input planes. \n | |||||
* The output is of size H x W, for any input size. | |||||
* @par Inputs: | |||||
* One input, including: | |||||
* @li x: A Tensor. Must be one of the following data types: | |||||
* float16, float32, float64. \n | |||||
* @par Attributes: | |||||
* @li output_size: A required list of 2 ints | |||||
* specifying the size (H,W) of the output tensor. \n | |||||
* @par Outputs: | |||||
* @li y: A Tensor. Has the same data type as "x" \n | |||||
* @par Third-party framework compatibility | |||||
* Compatible with the Pytorch operator AdaptiveMaxPool2d. | |||||
*/ | |||||
REG_OP(AdaptiveMaxPool2d) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) | |||||
.OUTPUT(argmax, TensorType::IndexNumberType()) | |||||
.REQUIRED_ATTR(output_size, ListInt) | |||||
.OP_END_FACTORY_REG(AdaptiveMaxPool2d) | |||||
/** | /** | ||||
* @brief Computes second-order gradients of the maxpooling3d function . \n | * @brief Computes second-order gradients of the maxpooling3d function . \n | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -187,16 +187,16 @@ REG_OP(DynamicRNNGrad) | |||||
*@brief: DynamicRNN calculation. | *@brief: DynamicRNN calculation. | ||||
*@par Inputs: | *@par Inputs: | ||||
*ten inputs: | *ten inputs: | ||||
*@li x:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li w:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li b:A 1D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li seq_length:A 1D Tensor. Must be one of the following types: int32. The format must be ND. | |||||
*@li init_h:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li init_c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li wci:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wcf:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wco:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li mask:A 1D Tensor. Must be one of the following types: uint8. The format must be ND . \n | |||||
*@li x:A required 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li w:A required 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li b:A required 1D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li seq_length:A optional 1D Tensor. Must be one of the following types: int32. The format must be ND. | |||||
*@li init_h:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li init_c:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li wci:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wcf:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wco:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li mask:A optional 1D Tensor. Must be one of the following types: uint8. The format must be ND . \n | |||||
*@par Attributes: | *@par Attributes: | ||||
*@li cell_type:An string identifying the cell type in the op. Default to "LSTM". Only LSTM is currently supported. | *@li cell_type:An string identifying the cell type in the op. Default to "LSTM". Only LSTM is currently supported. | ||||
@@ -221,6 +221,8 @@ REG_OP(DynamicRNNGrad) | |||||
*@li f:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | *@li f:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | ||||
*@li o:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | *@li o:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | ||||
*@li tanhct:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | *@li tanhct:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | ||||
*@par Third-party framework compatibility: | |||||
* Compatible with the TF operator LSTM. | |||||
*/ | */ | ||||
REG_OP(DynamicRNN) | REG_OP(DynamicRNN) | ||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) | .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) | ||||
@@ -258,17 +260,17 @@ REG_OP(DynamicRNN) | |||||
*@brief: DynamicLSTMV2 calculation. | *@brief: DynamicLSTMV2 calculation. | ||||
*@par Inputs: | *@par Inputs: | ||||
*ten inputs: | *ten inputs: | ||||
*@li x:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li w:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li b:A 1D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li cont:A 1D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li w_xc_x_static:A 1D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li h0:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li c0:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li wci:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wcf:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wco:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li mask:A 1D Tensor. Must be one of the following types: uint8. The format must be ND . \n | |||||
*@li x:A required 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li w:A required 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li b:A required 1D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li cont:A required 2D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li w_xc_x_static:A optional 2D Tensor. Must be one of the following types: float16, float32. The format must be ND. | |||||
*@li h0:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li c0:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | |||||
*@li wci:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wcf:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li wco:A optional 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM. | |||||
*@li mask:A optional 1D Tensor. Must be one of the following types: uint8. The format must be ND . | |||||
*@par Attributes: | *@par Attributes: | ||||
*@li num_output:An integer identifying the num projection in the op. Default to 0. | *@li num_output:An integer identifying the num projection in the op. Default to 0. | ||||
@@ -283,6 +285,10 @@ REG_OP(DynamicRNN) | |||||
*@li output_c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | *@li output_c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | ||||
*@li last_output_h:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | *@li last_output_h:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | ||||
*@li last_output_c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | *@li last_output_c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ. | ||||
*@par Third-party framework compatibility: | |||||
* Compatible with the Caffe operator LSTM. | |||||
*@par Restrictions: | |||||
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | */ | ||||
REG_OP(DynamicLSTMV2) | REG_OP(DynamicLSTMV2) | ||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) | .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) | ||||
@@ -854,6 +860,56 @@ REG_OP(EmbeddingDenseGrad) | |||||
.ATTR(padding_idx, Int, -1) | .ATTR(padding_idx, Int, -1) | ||||
.ATTR(scale_grad_by_freq, Bool, false) | .ATTR(scale_grad_by_freq, Bool, false) | ||||
.OP_END_FACTORY_REG(EmbeddingDenseGrad) | .OP_END_FACTORY_REG(EmbeddingDenseGrad) | ||||
/** | |||||
*@brief CommonLSTM calculation. A multi-layer LSTM op registration; input/attribute layout
* mirrors the ONNX LSTM contract (w, r, b, sequence_lens, initial_h, initial_c, p) —
* presumably p is the peephole weight tensor; confirm against the kernel implementation. \n
*@par Inputs:
*eight inputs: \n
*@li x:Each time step is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
*@li w:Each direction is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM.
*@li r:Each direction is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_ZN_LSTM.
*@li b:An optional input. Each direction is a 1D Tensor. Must be one of the following types: float16, float32. The format must be ND.
*@li sequence_lens:An optional input. A 1D Tensor. Must be one of the following types: int32. The format must be ND.
*@li initial_h:An optional input. Each direction is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
*@li initial_c:An optional input. Each direction is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
*@li p:An optional input. Each direction is a 1D Tensor. Must be one of the following types: float16, float32. The format must be ND.
*@par Attributes:
*@li activation_alpha:Optional scaling values used by some activation functions. Empty is currently supported.
*@li activation_beta:Optional scaling values used by some activation functions. Empty is currently supported.
*@li activations:The list of activation functions. Empty is currently supported.
*@li clip:A float identifying the cell clip in the op. Default to -1.0.
*@li direction:Specify if the RNN is forward, reverse, or bidirectional. Must be one of forward(default), reverse, or bidirectional.
*@li hidden_size:A required integer identifying the number of neurons in the hidden layer.
*@li input_forget:Couple the input and forget gates if 1. Default to 0. Reserved.
*@par Outputs:
*three outputs: \n
*@li y:First dimension is time step, second dimension is direction, others is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
*@li y_h:Each direction is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
*@li y_c:Each direction is a 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
*/ | |||||
REG_OP(CommonLSTM) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.INPUT(w, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.INPUT(r, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OPTIONAL_INPUT(sequence_lens, TensorType({DT_INT32})) | |||||
.OPTIONAL_INPUT(initial_h, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OPTIONAL_INPUT(initial_c, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OPTIONAL_INPUT(p, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OUTPUT(y_h, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OUTPUT(y_c, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.ATTR(activation_alpha, ListFloat, {}) | |||||
.ATTR(activation_beta, ListFloat, {}) | |||||
.ATTR(activations, ListString, {}) | |||||
.ATTR(clip, Float, -1.0) | |||||
.ATTR(direction, String, "forward") | |||||
.REQUIRED_ATTR(hidden_size, Int) | |||||
.ATTR(input_forget, Int, 0) | |||||
.OP_END_FACTORY_REG(CommonLSTM) | |||||
} // namespace ge | } // namespace ge | ||||
#endif // OPS_BUILT_IN_OP_PROTO_INC_RNN_H_ | #endif // OPS_BUILT_IN_OP_PROTO_INC_RNN_H_ |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -519,7 +519,8 @@ REG_OP(Unpack) | |||||
* @par Inputs: | * @par Inputs: | ||||
* x: A 4D Tensor with shape [batch, in_rows, in_cols, depth], Must be one of the | * x: A 4D Tensor with shape [batch, in_rows, in_cols, depth], Must be one of the | ||||
* following types:float32, double, int32, uint8, int16, int8, int64, uint16, | * following types:float32, double, int32, uint8, int16, int8, int64, uint16, | ||||
* float16, uint32, uint64 | |||||
* float16, uint32, uint64. The inputs must have data_format with one of follows: | |||||
* NHWC, NCHW. | |||||
* @par Attributes: | * @par Attributes: | ||||
* @li ksizes: A required list or tuple. The size of the sliding window for each | * @li ksizes: A required list or tuple. The size of the sliding window for each | ||||
@@ -534,7 +535,6 @@ REG_OP(Unpack) | |||||
* This is equivalent to rate in dilated (a.k.a. Atrous) convolutions. | * This is equivalent to rate in dilated (a.k.a. Atrous) convolutions. | ||||
* @li padding: A required string. The type of padding algorithm to use, | * @li padding: A required string. The type of padding algorithm to use, | ||||
support "SAME" or "VALID". \n | support "SAME" or "VALID". \n | ||||
* @li data_format: A required string. The format of input, only supported NHWC. \n | |||||
* @par Outputs: | * @par Outputs: | ||||
* y: A 4D Tensor with shape [batch, out_rows, out_cols, ksize_rows * | * y: A 4D Tensor with shape [batch, out_rows, out_cols, ksize_rows * | ||||
@@ -555,7 +555,6 @@ REG_OP(ExtractImagePatches) | |||||
.REQUIRED_ATTR(strides, ListInt) | .REQUIRED_ATTR(strides, ListInt) | ||||
.REQUIRED_ATTR(rates, ListInt) | .REQUIRED_ATTR(rates, ListInt) | ||||
.REQUIRED_ATTR(padding, String) | .REQUIRED_ATTR(padding, String) | ||||
.ATTR(data_format, String, "NHWC") | |||||
.OP_END_FACTORY_REG(ExtractImagePatches) | .OP_END_FACTORY_REG(ExtractImagePatches) | ||||
/** | /** | ||||
@@ -564,6 +563,7 @@ REG_OP(ExtractImagePatches) | |||||
* @par Inputs: | * @par Inputs: | ||||
* x: A 5D Tensor with shape [batch, in_planes, in_rows, in_cols, depth] . \n | * x: A 5D Tensor with shape [batch, in_planes, in_rows, in_cols, depth] . \n | ||||
* The inputs must have data_format with one of follows: NDHWC, NCDHW. \n | |||||
* @par Attributes: | * @par Attributes: | ||||
* @li ksizes: A required list or tuple. The size of the sliding window for each | * @li ksizes: A required list or tuple. The size of the sliding window for each | ||||
@@ -572,7 +572,6 @@ REG_OP(ExtractImagePatches) | |||||
* patches are in "x". Must be: [1, stride_planes, stride_rows, stride_cols, 1]. | * patches are in "x". Must be: [1, stride_planes, stride_rows, stride_cols, 1]. | ||||
* @li padding: A required string. The type of padding algorithm to use , | * @li padding: A required string. The type of padding algorithm to use , | ||||
* support "SAME" or "VALID" . \n | * support "SAME" or "VALID" . \n | ||||
* @li data_format: An optional string. The format of input, only supported NDHWC. \n | |||||
* @par Outputs: | * @par Outputs: | ||||
* Output: A 5D Tensor with shape [batch, out_planes, out_rows, out_cols, ksize_planes * | * Output: A 5D Tensor with shape [batch, out_planes, out_rows, out_cols, ksize_planes * | ||||
@@ -591,7 +590,6 @@ REG_OP(ExtractVolumePatches) | |||||
.REQUIRED_ATTR(ksizes, ListInt) | .REQUIRED_ATTR(ksizes, ListInt) | ||||
.REQUIRED_ATTR(strides, ListInt) | .REQUIRED_ATTR(strides, ListInt) | ||||
.REQUIRED_ATTR(padding, String) | .REQUIRED_ATTR(padding, String) | ||||
.ATTR(data_format, String, "NDHWC") | |||||
.OP_END_FACTORY_REG(ExtractVolumePatches) | .OP_END_FACTORY_REG(ExtractVolumePatches) | ||||
/** | /** | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,18 +1,18 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
* You may obtain a copy of the License at | * You may obtain a copy of the License at | ||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | * http://www.apache.org/licenses/LICENSE-2.0 | ||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | * Unless required by applicable law or agreed to in writing, software | ||||
* distributed under the License is distributed on an "AS IS" BASIS, | * distributed under the License is distributed on an "AS IS" BASIS, | ||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | |||||
*/ | |||||
#ifndef __CCE_RUNTIME_BASE_H__ | #ifndef __CCE_RUNTIME_BASE_H__ | ||||
#define __CCE_RUNTIME_BASE_H__ | #define __CCE_RUNTIME_BASE_H__ | ||||
@@ -81,11 +81,11 @@ typedef enum tagRtLimitType { | |||||
} rtLimitType_t; | } rtLimitType_t; | ||||
/**
 * @brief Describes one runtime task exception.
 *
 * NOTE(review): the pasted diff left every field duplicated (old and new
 * columns both survived); this restores the struct to a single, valid
 * definition — one uint32_t per field, five fields total.
 */
typedef struct rtExceptionInfo {
    uint32_t taskid;    // id of the task that raised the exception
    uint32_t streamid;  // id of the stream the task was submitted on
    uint32_t tid;       // thread id associated with the exception
    uint32_t deviceid;  // id of the device the task ran on
    uint32_t retcode;   // runtime return/error code
} rtExceptionInfo;
typedef void (*rtErrorCallback)(rtExceptionType); | typedef void (*rtErrorCallback)(rtExceptionType); | ||||