@@ -3,8 +3,17 @@ | |||
/output | |||
/prebuilts | |||
/cov | |||
/deps | |||
.autotools | |||
.project | |||
.cproject | |||
.settings/ | |||
/tests/frm/ | |||
*.ir | |||
*.out | |||
*.DS_Store | |||
.DS_Store | |||
server_config.sh | |||
# Dynamic libraries | |||
# *.so | |||
@@ -69,10 +69,10 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||
auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); | |||
if (job_id_iter != options.end()) { | |||
if (job_id_iter->second.length() > kMaxStrLen) { | |||
GELOGE(PARAM_INVALID,"[Check][JobId]Failed," | |||
GELOGE(PARAM_INVALID, "[Check][JobId]Failed," | |||
"the job_id [%s] string length: %zu > max string length: %d", | |||
job_id_iter->second.c_str(), job_id_iter->second.length(), kMaxStrLen); | |||
REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id","length"}), | |||
REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id", "length"}), | |||
std::vector<std::string>({job_id_iter->second, | |||
std::to_string(kMaxStrLen)})); | |||
return FAILED; | |||
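
The hunk above only adjusts spacing, but the surrounding check is a reusable pattern: look up an optional option in the map, then validate its length before use. A minimal standalone sketch of that validation, assuming an option key of "ge.exec.jobId" and a kMaxStrLen of 128 (both stand-ins, not taken from this diff):

    #include <iostream>
    #include <map>
    #include <string>

    // Stand-ins for OPTION_EXEC_JOB_ID and kMaxStrLen; the real values live in GE headers.
    const std::string kJobIdOption = "ge.exec.jobId";
    constexpr std::size_t kMaxStrLen = 128;

    bool CheckJobIdOption(const std::map<std::string, std::string> &options) {
      const auto it = options.find(kJobIdOption);
      if (it == options.end()) {
        return true;  // option not set: nothing to validate
      }
      if (it->second.length() > kMaxStrLen) {
        std::cerr << "job_id [" << it->second << "] string length " << it->second.length()
                  << " > max string length " << kMaxStrLen << "\n";
        return false;
      }
      return true;
    }

    int main() {
      std::map<std::string, std::string> options{{kJobIdOption, std::string(200, 'x')}};
      return CheckJobIdOption(options) ? 0 : 1;  // exits 1: the value is too long
    }
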
@@ -244,7 +244,7 @@ std::string GEGetWarningMsg() { | |||
// Initialize session,which calls innerSession | |||
Session::Session(const std::map<string, string> &options) { | |||
ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||
GELOGT(TRACE_INIT, "Session Constructor start"); | |||
GELOGT(TRACE_INIT, "Start to construct session."); | |||
ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||
// check init status | |||
@@ -332,7 +332,7 @@ Session::Session(const std::map<AscendString, AscendString> &options) { | |||
// session destructor | |||
Session::~Session() { | |||
ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize); | |||
GELOGT(TRACE_INIT, "Session Destructor start"); | |||
GELOGT(TRACE_INIT, "Start to destruct session."); | |||
// 0.check init status | |||
if (!g_ge_initialized) { | |||
GELOGW("GE is not yet initialized or is finalized."); | |||
@@ -602,16 +602,16 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s | |||
Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector<Tensor> &inputs, | |||
std::vector<Tensor> &outputs) { | |||
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||
GELOGT(TRACE_INIT, "Session run graph with stream async start"); | |||
GELOGT(TRACE_INIT, "Start to run graph with stream async."); | |||
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
if (instance_ptr == nullptr) { | |||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||
"[Run][Graph]Run graph with stream asyn failed, the GELib instance is nullptr," | |||
"[Run][Graph]Run graph with stream async failed, the GELib instance is nullptr," | |||
"session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); | |||
REPORT_INNER_ERROR("E19999", | |||
"Run graph with stream asyn failed, the GELib instance is nullptr" | |||
"Run graph with stream async failed, the GELib instance is nullptr" | |||
"session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); | |||
return FAILED; | |||
} | |||
@@ -66,21 +66,21 @@ void DumpOp::SetDynamicModelInfo(const string &dynamic_model_name, const string | |||
static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uintptr_t loop_cond, | |||
toolkit::aicpu::dump::OpMappingInfo &op_mapping_info) { | |||
if (step_id != 0) { | |||
GELOGI("step_id exists."); | |||
GELOGI("Exists step_id."); | |||
op_mapping_info.set_step_id_addr(static_cast<uint64_t>(step_id)); | |||
} else { | |||
GELOGI("step_id is null."); | |||
} | |||
if (loop_per_iter != 0) { | |||
GELOGI("loop_per_iter exists."); | |||
GELOGI("Exists loop_per_iter."); | |||
op_mapping_info.set_iterations_per_loop_addr(static_cast<uint64_t>(loop_per_iter)); | |||
} else { | |||
GELOGI("loop_per_iter is null."); | |||
} | |||
if (loop_cond != 0) { | |||
GELOGI("loop_cond exists."); | |||
GELOGI("Exists loop_cond."); | |||
op_mapping_info.set_loop_cond_addr(static_cast<uint64_t>(loop_cond)); | |||
} else { | |||
GELOGI("loop_cond is null."); | |||
@@ -253,7 +253,7 @@ Status DumpOp::LaunchDumpOp() { | |||
} | |||
if (device_id < 0) { | |||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][DeviceId]Failed, device_id %d", device_id); | |||
REPORT_INNER_ERROR("E19999","Check device_id %d failed", device_id); | |||
REPORT_INNER_ERROR("E19999", "Check device_id %d failed", device_id); | |||
return ACL_ERROR_GE_INTERNAL_ERROR; | |||
} | |||
toolkit::aicpu::dump::OpMappingInfo op_mapping_info; | |||
@@ -72,8 +72,7 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||
if (src_shape.at(kFracZHWC1) != dst_shape.at(kNchwH) * dst_shape.at(kNchwW) * c1 || | |||
src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | |||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||
"[Check][Shape]Failed to check relationship between src and dst shape, " | |||
"src shape %s, dst shape %s", | |||
"[Check][Shape]Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||
REPORT_INNER_ERROR("E19999", "Failed to check relationship between src and dst shape, " | |||
"src shape %s, dst shape %s", | |||
@@ -138,9 +137,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
"[Operate][Memory]Failed to copy data from FracZ offset %ld to " | |||
"NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||
REPORT_CALL_ERROR("E19999","Failed to copy data from FracZ offset %ld to " | |||
REPORT_CALL_ERROR("E19999", "Failed to copy data from FracZ offset %ld to " | |||
"NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret ); | |||
src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
} | |||
} | |||
@@ -44,23 +44,20 @@ Status CheckArgsForFracZToNhwc(const TransArgs &args) { | |||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, " | |||
"shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, " | |||
"invalid data type %s", | |||
REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||
} | |||
if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | |||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||
ShapeToString(src_shape).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||
ShapeToString(src_shape).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", ShapeToString(src_shape).c_str()); | |||
return ACL_ERROR_GE_SHAPE_INVALID; | |||
} | |||
if (!CheckShapeValid(dst_shape, kNhwcDimsNum)) { | |||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||
ShapeToString(dst_shape).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||
ShapeToString(dst_shape).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str()); | |||
return ACL_ERROR_GE_SHAPE_INVALID; | |||
} | |||
int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||
@@ -138,7 +135,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
"[Operate][Memory]Failed to copy data from FracZ offset %ld to " | |||
"NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||
REPORT_CALL_ERROR("E19999","Failed to copy data from FracZ offset %ld to " | |||
REPORT_CALL_ERROR("E19999", "Failed to copy data from FracZ offset %ld to " | |||
"NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
@@ -185,7 +182,7 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult & | |||
ShapeToString(args.src_shape).c_str(), | |||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||
REPORT_CALL_ERROR("E19999","Failed to get data after trans, src shape %s, data type %s, " | |||
REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, data type %s, " | |||
"dst shape %s, memory size %ld, error_code %u", | |||
ShapeToString(args.src_shape).c_str(), | |||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
@@ -112,11 +112,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
total_size, ShapeToString(args.dst_shape).c_str(), | |||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||
"shape %s when trans format from %s to %s", | |||
total_size, ShapeToString(args.dst_shape).c_str(), | |||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, shape %s when trans format from %s to %s", | |||
total_size, ShapeToString(args.dst_shape).c_str(), | |||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
} | |||
@@ -47,7 +47,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg | |||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Shape]Failed, input data is null " | |||
"or shape size not euqal to 0, src_shape %s", | |||
ShapeToString(args.src_shape).c_str()); | |||
REPORT_CALL_ERROR("E19999","Failed to check shape, input data is null " | |||
REPORT_CALL_ERROR("E19999", "Failed to check shape, input data is null " | |||
"or shape size not equal to 0, src_shape %s", | |||
ShapeToString(args.src_shape).c_str()); | |||
return ACL_ERROR_GE_PARAM_INVALID; | |||
@@ -79,7 +79,8 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil | |||
Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { | |||
vector<int64_t> om_info; | |||
auto ge_model_weight = ge_model->GetWeight(); | |||
GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | |||
GELOGD("SaveSizeToModelDef weight_data_size is %zu, ge_model_weight data is %p", ge_model_weight.GetSize(), | |||
ge_model_weight.GetData()); | |||
om_info.push_back(ge_model_weight.GetSize()); | |||
TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | |||
@@ -284,7 +285,7 @@ Status ModelHelper::SaveAllModelPartiton(std::shared_ptr<OmFileSaveHelper>& om_f | |||
if (SaveModelWeights(om_file_save_helper, ge_model, model_index) != SUCCESS) { | |||
GELOGE(FAILED, "[Save][ModelWeights]Failed, model %s, model index %zu", | |||
ge_model->GetName().c_str(), model_index); | |||
REPORT_CALL_ERROR("E19999","ModelHelper save mode weights failed, model %s, model index %zu", | |||
REPORT_CALL_ERROR("E19999", "ModelHelper save mode weights failed, model %s, model index %zu", | |||
ge_model->GetName().c_str(), model_index); | |||
return FAILED; | |||
} | |||
@@ -441,7 +442,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRoo | |||
GELOGE(INTERNAL_ERROR, "[Save][AllModelPartition]Failed, model name %s, cur_index %zu", | |||
model_name.c_str(), cur_index); | |||
REPORT_CALL_ERROR("E19999", "Save all model %s partition failed, cur_index %zu", | |||
model_name.c_str(), cur_index); | |||
model_name.c_str(), cur_index); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
@@ -459,7 +460,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRoo | |||
GELOGE(FAILED, "[Save][Model]OmFileSaveHelper save model eturn fail, output_file %s", | |||
output_file.c_str()); | |||
REPORT_CALL_ERROR("E19999", "OmFileSaveHelper save model return fail, output_file %s", | |||
output_file.c_str()); | |||
output_file.c_str()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
@@ -601,7 +602,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { | |||
if (model_data.model_data == nullptr || model_data.model_len == 0) { | |||
GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "[Load][RootModel] " | |||
"Model_data is nullptr or model_data_size is 0"); | |||
"Model_data is nullptr or model data is empty."); | |||
REPORT_INNER_ERROR("E19999", "Load root model failed, model_data is nullptr or its size is 0"); | |||
return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||
} | |||
@@ -628,7 +629,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootMod | |||
//model verison 1.0 file header does not have model_num member | |||
is_unknown_shape_model_ = file_header_->version >= ge::MODEL_VERSION && | |||
file_header_->model_num > kStatiOmFileModelNum; | |||
GELOGD("cur om model is ge root model or no %d, model version %u", is_unknown_shape_model_, file_header_->version); | |||
GELOGD("Cur om model is ge root model or no %d, model version %u", is_unknown_shape_model_, file_header_->version); | |||
OmFileLoadHelper om_load_helper; | |||
if (is_unknown_shape_model_) { | |||
@@ -650,7 +651,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootMod | |||
GELOGE(status, "[Generate][GERootModel]Failed"); | |||
return status; | |||
} | |||
GELOGD("in ModelHelper::LoadRootModel, is_assign_model_ is setted to true!"); | |||
GELOGD("In ModelHelper::LoadRootModel, is_assign_model_ is setted to true!"); | |||
is_assign_model_ = true; | |||
return SUCCESS; | |||
} | |||
@@ -790,7 +791,7 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { | |||
if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { | |||
GELOGE(FAILED, "[Get][ModelWeightPartition]Failed, GetWeight size:%u", partition.size); | |||
REPORT_CALL_ERROR("E19999", "[Get][ModelPartition]Failed, GetWeight size:%u", | |||
partition.size); | |||
partition.size); | |||
return FAILED; | |||
} | |||
ge::Buffer weight = ge::Buffer::CopyFrom(partition.data, partition.size); | |||
@@ -805,7 +806,7 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cu | |||
if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition, mode_index) != SUCCESS) { | |||
GELOGE(FAILED, "[Get][ModelPartition]Failed, GetWeight size:%u", partition.size); | |||
REPORT_CALL_ERROR("E19999", "[Get][ModelPartition]Failed, GetWeight size:%u", | |||
partition.size); | |||
partition.size); | |||
return FAILED; | |||
} | |||
ge::Buffer weight = ge::Buffer::CopyFrom(partition.data, partition.size); | |||
@@ -444,17 +444,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status | |||
OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector<int64_t> &dims) { | |||
if (tensor == nullptr) { | |||
GELOGE(PARAM_INVALID, "[Check][Param]Input tensor is nullptr"); | |||
REPORT_INNER_ERROR("E19999","Input tensor is nullptr"); | |||
REPORT_INNER_ERROR("E19999", "Input tensor is nullptr"); | |||
return PARAM_INVALID; | |||
} | |||
// If the tensor data is a vector, the shape dimension must be 1 | |||
if (tensor->GetTensorDesc().GetShape().GetDims().size() > 1) { | |||
GELOGE(PARAM_INVALID, "[Check][Param]The dimension of the input tensor shape " | |||
"cannot be more than 1, it is %zu", | |||
GELOGE(PARAM_INVALID, "[Check][Param]The dimension of the input tensor shape cannot be more than 1, it is %zu", | |||
tensor->GetTensorDesc().GetShape().GetDims().size()); | |||
REPORT_CALL_ERROR("E19999", "The dimension of the input tensor shape %zu invalid, " | |||
"more than 1", tensor->GetTensorDesc().GetShape().GetDims().size()); | |||
REPORT_CALL_ERROR("E19999", "The dimension of the input tensor shape %zu invalid, more than 1", | |||
tensor->GetTensorDesc().GetShape().GetDims().size()); | |||
return PARAM_INVALID; | |||
} | |||
@@ -473,8 +472,8 @@ OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType ty | |||
dims.push_back(shape_data[i]); | |||
} | |||
} else { | |||
GELOGE(PARAM_INVALID, "[Check][DataType]Invalid, type only can be DT_INT32 or DT_INT64, " | |||
"type is %s", TypeUtils::DataTypeToSerialString(type).c_str()); | |||
GELOGE(PARAM_INVALID, "[Check][DataType]Invalid, type only can be DT_INT32 or DT_INT64, type is %s", | |||
TypeUtils::DataTypeToSerialString(type).c_str()); | |||
REPORT_INNER_ERROR("E19999", "Data type %s check invalid, only can be DT_INT32 or DT_INT64", | |||
TypeUtils::DataTypeToSerialString(type).c_str()); | |||
return PARAM_INVALID; | |||
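
The branch above only accepts DT_INT32 or DT_INT64 before the tensor payload is widened into a vector<int64_t>. A minimal standalone sketch of that widening step, with a toy DataType enum and a raw buffer standing in for the GE tensor (not the real GeTensor API):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    enum class DataType { DT_INT32, DT_INT64 };  // toy stand-in for ge::DataType

    // Widen raw shape data into int64_t dims; false mirrors the PARAM_INVALID branch above.
    bool GetShapeData(const void *data, std::size_t count, DataType type, std::vector<int64_t> &dims) {
      if (type == DataType::DT_INT32) {
        const int32_t *p = static_cast<const int32_t *>(data);
        for (std::size_t i = 0; i < count; ++i) {
          dims.push_back(static_cast<int64_t>(p[i]));
        }
        return true;
      }
      if (type == DataType::DT_INT64) {
        const int64_t *p = static_cast<const int64_t *>(data);
        dims.insert(dims.end(), p, p + count);
        return true;
      }
      return false;
    }

    int main() {
      const int32_t shape[] = {8, 3, 224, 224};
      std::vector<int64_t> dims;
      GetShapeData(shape, 4, DataType::DT_INT32, dims);
      std::cout << dims.size() << " dims\n";  // 4 dims
      return 0;
    }
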
@@ -304,7 +304,7 @@ std::string DNNEngineManager::GetHostCpuEngineName(const std::vector<OpInfo> &op | |||
GELOGE(FAILED, "[Get][HostCpuEngineName]Failed, HostCpuEngine not support [%s, %s]", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
REPORT_INNER_ERROR("E19999", "Get HostCpuEngineName failed, HostCpuEngine not support [%s, %s]", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
return ""; | |||
} | |||
@@ -436,7 +436,7 @@ Status DNNEngineManager::ParserEngineMessage(const json engines_json, const std: | |||
GELOGE(FAILED, "[Check][Param]There are the same engine %s message in the json file", | |||
engine_id.c_str()); | |||
REPORT_INNER_ERROR("E19999", "There are the same engine %s message in the json file", | |||
engine_id.c_str()); | |||
engine_id.c_str()); | |||
return FAILED; | |||
} | |||
engines.emplace(engine_id, engine_conf_ptr); | |||
@@ -684,7 +684,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||
bool is_allocated_first_input = is_continuous_input_allocated && (in_data_anchor->GetIdx() == 0); | |||
if (is_allocated_first_input) { | |||
std::map<int32_t, int32_t> out2ins; | |||
GE_CHK_STATUS_RET(TryGetNodeRefIndexes(node, out2ins), "[Get][RefIndexes]fail for node: %s", node->GetName().c_str()); | |||
GE_CHK_STATUS_RET(TryGetNodeRefIndexes(node, out2ins), "[Get][RefIndexes]fail for node: %s", | |||
node->GetName().c_str()); | |||
// output is beginning offset, set offset for input; only support this case now | |||
if ((out2ins.size() == 1) && (out2ins.begin()->second == 0) && (reverse_refresh)) { | |||
auto peer_output_offset = output_list.at(peer_out_data_anchor->GetIdx()); | |||
@@ -246,7 +246,8 @@ Status ModelBuilder::SetInputOutputDesc() { | |||
} | |||
// if user set input node format ND, the expected node for data and netoutput format is ND in | |||
// final graph. | |||
if ((compute_graph_->GetParentGraph() == nullptr) && (GetLocalOmgContext().format == domi::DOMI_TENSOR_ND) && (!node_op_desc->HasAttr("_is_single_op")) && | |||
if ((compute_graph_->GetParentGraph() == nullptr) && (GetLocalOmgContext().format == domi::DOMI_TENSOR_ND) && | |||
(!node_op_desc->HasAttr("_is_single_op")) && | |||
((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { | |||
auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); | |||
auto outputDescsPtr = node_op_desc->GetAllOutputsDescPtr(); | |||
@@ -193,23 +193,29 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node) { | |||
/// | |||
/// @brief set op next_iteration name | |||
/// @param [in] node | |||
/// @param [in] next | |||
/// @param [in] Merge Node | |||
/// @param [in] NextIteration Node | |||
/// @return Status | |||
/// | |||
Status SetNextIteration(const ge::NodePtr &node, const std::string &next) { | |||
Status SetNextIteration(const NodePtr &node, const NodePtr &next) { | |||
GE_CHECK_NOTNULL(node); | |||
OpDescPtr tmp_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(tmp_desc); | |||
GE_CHECK_NOTNULL(next); | |||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
GE_CHECK_NOTNULL(next->GetOpDesc()); | |||
if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) { | |||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||
node->GetName().c_str(), node->GetType().c_str()); | |||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||
node->GetName().c_str(), node->GetType().c_str()); | |||
return FAILED; | |||
} | |||
const auto SetIterationName = [](const OpDescPtr &op_desc, const std::string &name) { | |||
if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) { | |||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
}; | |||
GE_CHK_STATUS_RET_NOLOG(SetIterationName(node->GetOpDesc(), next->GetName())); | |||
GE_CHK_STATUS_RET_NOLOG(SetIterationName(next->GetOpDesc(), node->GetName())); | |||
return SUCCESS; | |||
} | |||
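
The rewrite above takes the NextIteration node itself instead of its name and cross-links the two ops through a local helper, so each side records the other's name under the same attribute. A standalone sketch of that symmetric linking with toy types (the attribute key below is a stand-in, not ge::ATTR_NAME_NEXT_ITERATION):

    #include <iostream>
    #include <map>
    #include <string>

    struct OpDesc {
      std::string name;
      std::map<std::string, std::string> attrs;
    };

    // Cross-link two ops under the same attribute key, as the rewritten pass does for
    // the Merge and NextIteration nodes (toy stand-in for AttrUtils::SetStr).
    void SetNextIteration(OpDesc &node, OpDesc &next) {
      const auto set_iteration_name = [](OpDesc &op, const std::string &peer_name) {
        op.attrs["_next_iteration_name"] = peer_name;
      };
      set_iteration_name(node, next.name);
      set_iteration_name(next, node.name);
    }

    int main() {
      OpDesc merge{"merge_0", {}};
      OpDesc next_iter{"next_iteration_0", {}};
      SetNextIteration(merge, next_iter);
      std::cout << merge.attrs["_next_iteration_name"] << "\n";  // prints next_iteration_0
      return 0;
    }
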
@@ -96,11 +96,11 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node); | |||
/// | |||
/// @brief set op next_iteration name | |||
/// @param [in] node | |||
/// @param [in] next | |||
/// @param [in] Merge Node | |||
/// @param [in] NextIteration Node | |||
/// @return Status | |||
/// | |||
Status SetNextIteration(const ge::NodePtr &node, const std::string &next); | |||
Status SetNextIteration(const NodePtr &node, const NodePtr &next); | |||
/// | |||
/// @brief Align the memory | |||
@@ -704,7 +704,7 @@ Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector<int64_t> | |||
} | |||
Status GraphExecutor::GetOpAttr(uint32_t model_id, const std::string &op_name, const std::string &attr_name, | |||
std::string &attr_value) { | |||
std::string &attr_value) { | |||
auto model_manager = ge::ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
Status ret = model_manager->GetOpAttr(model_id, op_name, attr_name, attr_value); | |||
@@ -886,6 +886,7 @@ Status GraphManager::PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, | |||
GM_RUN_AND_DUMP_PERF("OptimizeSwitchOp", stages.preparer.SwitchOpOptimize, compute_graph); | |||
} | |||
GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph); | |||
GM_RUN_AND_DUMP_PERF("OptimizeAfterStage1", stages.optimizer.OptimizeAfterStage1, compute_graph); | |||
GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed); | |||
PassManager graph_pass; | |||
@@ -3118,7 +3119,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
GraphNodePtr graph_node = nullptr; | |||
Status ret = graph_manager->GetGraphNode(args.graph_id, graph_node); | |||
if (ret != SUCCESS) { | |||
ReturnError(graph_manager, args.callback, GE_GRAPH_ALREADY_RUNNING, | |||
ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL, | |||
"[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); | |||
return; | |||
} | |||
@@ -3143,7 +3144,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
graph_node->Lock(); | |||
if (graph_node->GetRunFlag()) { | |||
ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL, | |||
ReturnError(graph_manager, args.callback, GE_GRAPH_ALREADY_RUNNING, | |||
"[RunGraph] graph already running, graph id=" + std::to_string(args.graph_id)); | |||
graph_node->Unlock(); | |||
return; | |||
@@ -489,7 +489,7 @@ Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { | |||
mem_resource = MemResource::BuildMemResourceFromType(memory_type); | |||
if (mem_resource == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", | |||
memory_type, session_id_); | |||
memory_type, session_id_); | |||
GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu", | |||
memory_type, session_id_); | |||
return ge::INTERNAL_ERROR; | |||
@@ -275,7 +275,8 @@ Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, Hccl | |||
"check invalid", ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in Op:%s(%s), horovod_op_type value:%ld is not support now", | |||
ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | |||
ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
horovod_op_type); | |||
return PARAM_INVALID; | |||
} | |||
op_type = iter->second; | |||
@@ -155,7 +155,7 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | |||
} | |||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||
GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | |||
GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %zu.", | |||
graph_optimizer.size()); | |||
string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | |||
GELOGD("[OptimizeOriginalGraph]: engine type will exclude: %s", exclude_core_Type.c_str()); | |||
@@ -194,7 +194,7 @@ Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_ | |||
} | |||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||
GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | |||
GELOGI("optimize by opskernel in judging insert phase. num of graph_optimizer is %zu.", | |||
graph_optimizer.size()); | |||
string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | |||
if (graph_optimizer.size() != 0) { | |||
@@ -294,6 +294,46 @@ Status GraphOptimize::OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_gr | |||
return ret; | |||
} | |||
Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) { | |||
GE_CHECK_NOTNULL(compute_graph); | |||
GELOGD("OptimizeAfterStage1 in"); | |||
if (GetContext().GetHostExecFlag()) { | |||
// graph exec on host, no need OptimizeAfterStage1 | |||
return SUCCESS; | |||
} | |||
Status ret = SUCCESS; | |||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
REPORT_INNER_ERROR("E19999", "Gelib not init before, check invalid"); | |||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeAfterStage1 failed."); | |||
return GE_CLI_GE_NOT_INITIALIZED; | |||
} | |||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||
GELOGI("Optimize by ops kernel in after stage1 phase, num of graph_optimizer is %zu.", graph_optimizer.size()); | |||
string exclude_core_type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | |||
if (graph_optimizer.size() != 0) { | |||
for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { | |||
if (iter->first == exclude_core_type) { | |||
GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str()); | |||
continue; | |||
} | |||
#ifndef ONLY_COMPILE_OPEN_SRC | |||
GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str()); | |||
ret = (iter->second)->OptimizeAfterStage1(*compute_graph); | |||
#endif | |||
if (ret != SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, " | |||
"graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str()); | |||
GELOGE(ret, "[OptimizeAfterStage1]: graph optimize failed, ret:%d.", ret); | |||
return ret; | |||
} | |||
} | |||
} | |||
return ret; | |||
} | |||
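
The new OptimizeAfterStage1 follows the same control flow as the other optimize phases: walk the optimizers in priority order, skip the excluded engine type, and stop at the first failing engine. A toy sketch of that loop, with std::function standing in for the GraphOptimizer interface (not GE's real types):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    using Optimizer = std::pair<std::string, std::function<int()>>;

    // Run optimizers in priority order, skipping one engine and stopping on the first error.
    int RunByPriority(const std::vector<Optimizer> &optimizers, const std::string &exclude) {
      for (const auto &opt : optimizers) {
        if (opt.first == exclude) {
          continue;  // engine type excluded for this core type
        }
        const int ret = opt.second();
        if (ret != 0) {
          return ret;  // propagate the failing engine's status
        }
      }
      return 0;
    }

    int main() {
      std::vector<Optimizer> optimizers{
          {"AIcoreEngine", [] { return 0; }},
          {"VectorEngine", [] { return 1; }},  // would fail, but is excluded below
      };
      std::cout << RunByPriority(optimizers, "VectorEngine") << "\n";  // prints 0
      return 0;
    }
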
Status GraphOptimize::SetOptions(const ge::GraphManagerOptions &options) { | |||
if (options.framework_type >= static_cast<int32_t>(domi::FrameworkType::FRAMEWORK_RESERVED)) { | |||
REPORT_INNER_ERROR("E19999", "Param framework_type:%d in option check invalid", | |||
@@ -58,6 +58,9 @@ class GraphOptimize { | |||
// for rts optimize before build to add attr and insert memcpy op | |||
Status OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph); | |||
// optimize whole graph, using after stage1 | |||
Status OptimizeAfterStage1(ComputeGraphPtr &graph); | |||
// set options | |||
Status SetOptions(const GraphManagerOptions &options); | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
@@ -22,6 +22,7 @@ | |||
#include "graph/optimize/graph_optimize.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
namespace { | |||
using namespace ge; | |||
@@ -32,12 +33,14 @@ const int kCaseReadOnly = 0; | |||
const int kCaseScopeWriteable = 2; | |||
const int kCaseWriteable = 3; | |||
const int kCaseInvalidRWType = 5; | |||
// attr _input_mutable = true means node will modify its input in runtime | |||
const char *const kModifyInput = "_input_mutable"; | |||
// rw type of input. | |||
enum class InputRWType { | |||
kReadOnly, // Normal op input only read | |||
kWriteable, // Op like Assign/ApplyMomentum | |||
kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput | |||
kScopeWriteable, // Op like hcom_allreduce/while, it will modify input ,but not expect take effect on pre ouput | |||
kInvalidRWType | |||
}; | |||
// rw type of output | |||
@@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) { | |||
return true; | |||
} | |||
NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||
NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) { | |||
if (src_node.GetOpDesc() == nullptr) { | |||
return nullptr; | |||
} | |||
@@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||
auto next_num = identity_num.fetch_add(1); | |||
// 1. create new identity op desc | |||
string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | |||
auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY); | |||
if (identity_opdesc == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); | |||
return nullptr; | |||
} | |||
OpDescBuilder op_desc_builder(identity_name, IDENTITY); | |||
auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | |||
// 2. add input_desc & output_desc for new identity | |||
Status ret = identity_opdesc->AddInputDesc("x", data_desc); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); | |||
return nullptr; | |||
} | |||
ret = identity_opdesc->AddOutputDesc("y", data_desc); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); | |||
return nullptr; | |||
} | |||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc) | |||
.AddOutput("y", data_desc) | |||
.Build(); | |||
GELOGI("Insert new Identity node %s.", identity_name.c_str()); | |||
auto graph = src_node.GetOwnerComputeGraph(); | |||
if (graph == nullptr) { | |||
GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | |||
return nullptr; | |||
} | |||
return graph->AddNode(identity_opdesc); | |||
return graph->AddNode(identity_op_desc); | |||
} | |||
OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | |||
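
The refactor above collapses the manual MakeShared<OpDesc> + AddInputDesc/AddOutputDesc sequence into one OpDescBuilder chain. A toy fluent builder showing why the call site shrinks (an illustration of the pattern, not GE's real OpDescBuilder):

    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct TensorDesc { std::string format; };
    struct OpDesc {
      std::string name;
      std::string type;
      std::vector<std::pair<std::string, TensorDesc>> inputs;
      std::vector<std::pair<std::string, TensorDesc>> outputs;
    };

    // Toy fluent builder: AddInput/AddOutput return *this so the calls chain,
    // and Build() hands back the finished descriptor in one expression.
    class OpDescBuilder {
     public:
      OpDescBuilder(std::string name, std::string type) { desc_.name = std::move(name); desc_.type = std::move(type); }
      OpDescBuilder &AddInput(const std::string &n, const TensorDesc &d) { desc_.inputs.emplace_back(n, d); return *this; }
      OpDescBuilder &AddOutput(const std::string &n, const TensorDesc &d) { desc_.outputs.emplace_back(n, d); return *this; }
      std::shared_ptr<OpDesc> Build() const { return std::make_shared<OpDesc>(desc_); }
     private:
      OpDesc desc_;
    };

    int main() {
      TensorDesc data_desc{"NCHW"};
      auto identity = OpDescBuilder("node_Identity_0", "Identity").AddInput("x", data_desc).AddOutput("y", data_desc).Build();
      return (identity->inputs.size() == 1 && identity->outputs.size() == 1) ? 0 : 1;
    }
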
@@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||
// single node without sub graph | |||
return GetSingleNodeInputRWTypeByIndex(node, index); | |||
} else { | |||
// node with sub graph | |||
std::set<int> node_rw_type_set; | |||
auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | |||
// get all input data node in subgraph | |||
std::set<int> anchor_rw_type_set; | |||
@@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { | |||
auto parent_node = sub_graph->GetParentNode(); | |||
if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | |||
// insert identity | |||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Fail to insert identity"); | |||
return ret; | |||
if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||
identity_node->GetName().c_str(), | |||
identity_node->GetType().c_str(), | |||
pre_node->GetName().c_str(), | |||
pre_node->GetType().c_str(), | |||
node->GetName().c_str(), | |||
node->GetType().c_str()); | |||
GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||
identity_node->GetName().c_str(), | |||
identity_node->GetType().c_str(), | |||
pre_node->GetName().c_str(), | |||
pre_node->GetType().c_str(), | |||
node->GetName().c_str(), | |||
node->GetType().c_str()); | |||
return FAILED; | |||
} | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
@@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I | |||
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||
GE_CHECK_NOTNULL(peer_in_data_node); | |||
auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | |||
auto ret = out_data_anchor->Unlink(peer_in_data_anchor); | |||
auto old_identity = out_data_anchor->GetOwnerNode(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
return ret; | |||
} | |||
if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | |||
auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); | |||
auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(new_identity); | |||
if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS | |||
|| GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", | |||
pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
// 2. copy in-control-edge from dst to Identity | |||
if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), | |||
new_identity->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex, | |||
kIdentityAnchorIndex); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Failed to insert Identity %s before %s %dth input.", | |||
new_identity->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_data_anchor->GetIdx()); | |||
return ret; | |||
} | |||
GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | |||
InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
} else { | |||
(void) out_data_anchor->Unlink(peer_in_data_anchor); | |||
// copy control edge to pre and peer node | |||
if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | |||
|| GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | |||
@@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
GELOGD("No need insert Identity."); | |||
continue; | |||
case INSERT_IDENTITY: | |||
auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); | |||
if (identity_node == nullptr) { | |||
GELOGE(FAILED, "Create identity node failed."); | |||
return FAILED; | |||
} | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); | |||
if (ret != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), | |||
peer_in_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex, | |||
kIdentityAnchorIndex); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx()); | |||
return ret; | |||
} | |||
GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | |||
peer_in_node->GetName().c_str()); | |||
@@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||
for (const auto &node : compute_graph->GetDirectNode()) { | |||
if (node->GetType() == HCOMALLREDUCE) { | |||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(pre_out_anchor); | |||
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||
pre_out_anchor_set.emplace(pre_out_anchor); | |||
continue; | |||
} | |||
// need insert identity | |||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
for (const auto &node : compute_graph->GetDirectNode()) { | |||
bool mutable_input_flag = false; | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||
if (!mutable_input_flag) { | |||
continue; | |||
} | |||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(pre_out_anchor); | |||
if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||
continue; | |||
} | |||
// need insert identity | |||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = | |||
GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), | |||
node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
return ret; | |||
} | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace | |||
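
The rewritten HandleAllreduceDuplicateInput relies on std::set::insert returning an {iterator, bool} pair, so a single call both records the producer anchor and reports whether it was already present. A minimal illustration of the idiom:

    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
      // insert(...).second is true only for the first occurrence, so duplicates are
      // detected in one lookup -- the idiom the rewritten pass uses for peer out anchors.
      std::vector<std::string> anchors = {"var_a:0", "var_b:0", "var_a:0"};
      std::set<std::string> pre_out_anchor_set;
      for (const auto &a : anchors) {
        if (pre_out_anchor_set.insert(a).second) {
          continue;  // first time this producer is seen, nothing to do
        }
        std::cout << "duplicate input from " << a << ", would insert Identity\n";
      }
      return 0;
    }
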
@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() { | |||
if (!in_cluster->IsUnknownShape()) { | |||
continue; | |||
} | |||
if (!cluster->IsAdjoinNodes(in_cluster)) { | |||
continue; | |||
} | |||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | |||
GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | |||
ToString(merged_clusters).c_str()); | |||
@@ -80,6 +80,10 @@ class DynamicShapePartitioner { | |||
Status BuildPartitionSubgraph(); | |||
// Clear resource and break circular dependency | |||
void Clear(); | |||
bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const { | |||
const auto &out_clusters = other->out_clusters_; | |||
return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end(); | |||
} | |||
private: | |||
static thread_local size_t unique_id_; | |||
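
IsAdjoinNodes above checks adjacency by pointer identity: it looks for shared_from_this() inside the other cluster's successor list. A self-contained sketch of the same test with a toy Cluster (out_clusters_ assumed to be the successor list, as in the snippet above):

    #include <algorithm>
    #include <iostream>
    #include <memory>
    #include <vector>

    // Toy stand-in for the partitioner's Cluster: a node that knows its successors.
    struct Cluster : std::enable_shared_from_this<Cluster> {
      std::vector<std::shared_ptr<Cluster>> out_clusters_;

      // True when *this appears in other's successor list, i.e. the clusters share a
      // direct edge -- the same test the new IsAdjoinNodes guard performs.
      bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const {
        const auto &outs = other->out_clusters_;
        return std::find(outs.begin(), outs.end(), shared_from_this()) != outs.end();
      }
    };

    int main() {
      auto a = std::make_shared<Cluster>();
      auto b = std::make_shared<Cluster>();
      a->out_clusters_.push_back(b);  // edge a -> b
      std::cout << std::boolalpha << b->IsAdjoinNodes(a) << "\n";  // true: a points at b
      std::cout << a->IsAdjoinNodes(b) << "\n";                    // false: b has no successors
      return 0;
    }
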
@@ -451,7 +451,7 @@ Status AtomicAddrCleanPass::CompileUnknownGraphOp(const vector<NodePtr> &atomic_ | |||
GE_TIMESTAMP_ADD(UnknownGraphCompileOp); | |||
if (ret != ge::SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Call CompileOp failed, kernel_lib_name:%s, ret:%d", | |||
kernel_lib_name.c_str(), ret); | |||
kernel_lib_name.c_str(), ret); | |||
GELOGE(ret, "Compile atomic op failed, kernel lib name is %s", kernel_lib_name.c_str()); | |||
return ret; | |||
} | |||
@@ -29,7 +29,8 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||
std::map<NodePtr, NodePtr> branch_head_nodes; | |||
FindNodes(graph, need_label_nodes, enter_nodes, branch_head_nodes); | |||
for (const auto &node : need_label_nodes) { | |||
GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||
GE_CHK_STATUS_RET(UpdateCondBranch(node, branch_head_nodes), "Update cond branch failed, start node:%s.", | |||
node->GetName().c_str()); | |||
} | |||
GE_CHK_STATUS_RET(UpdateEnterNode(enter_nodes), "UpdateEnterNode failed."); | |||
@@ -62,7 +62,7 @@ Status BitcastPass::CheckDstDataType(const OpDescPtr op_desc, ge::DataType &dst_ | |||
if (!ge::AttrUtils::GetDataType(op_desc, kAttrNameType, dst_data_type)) { | |||
REPORT_CALL_ERROR("E19999", "Get Attr:%s of op:%s(%s) failed", | |||
kAttrNameType, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
kAttrNameType, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(PARAM_INVALID, "Node failed to get attribute type."); | |||
return PARAM_INVALID; | |||
} | |||
@@ -166,7 +166,7 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph | |||
if (iter == subgraph_names_to_index.end()) { | |||
REPORT_INNER_ERROR("E19999", "subgraph name:%s not exist in SubgraphNameIndexes map of op:%s(%s), " | |||
"check invalid", ATTR_NAME_WHILE_COND.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(FAILED, "Get cond_graph index failed, while_node:%s.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
@@ -65,13 +65,13 @@ Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { | |||
for (auto &in_control_node : n->GetInControlNodes()) { | |||
GE_CHECK_NOTNULL(in_control_node); | |||
GE_CHK_GRAPH_STATUS_RET(ge::GraphUtils::RemoveEdge(in_control_node->GetOutControlAnchor(), | |||
n->GetInControlAnchor()), "remove edge failed"); | |||
n->GetInControlAnchor()), "remove edge failed"); | |||
for (auto &out_node : n->GetOutNodes()) { | |||
if (out_node == nullptr) { | |||
continue; | |||
} | |||
GE_CHK_GRAPH_STATUS_RET(ge::GraphUtils::AddEdge(in_control_node->GetOutControlAnchor(), | |||
out_node->GetInControlAnchor()), "add edge failed."); | |||
out_node->GetInControlAnchor()), "add edge failed."); | |||
} | |||
} | |||
} | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -24,11 +24,12 @@ | |||
#include "common/ge/ge_util.h" | |||
#include "framework/common/types.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
namespace { | |||
const int kAnchorNum = 0; | |||
const int32_t kAnchorAssignRefIndex = 0; | |||
const int32_t kAnchorAssignValueIndex = 1; | |||
const int32_t kAnchorIdentityIndex = 0; | |||
} // namespace | |||
namespace ge { | |||
Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
@@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap | |||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||
node_name = CheckDuplicateName(node_name); | |||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||
if (op_desc == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||
return nullptr; | |||
} | |||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||
return nullptr; | |||
} | |||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||
if (identity_op_desc == nullptr) { | |||
return nullptr; | |||
} | |||
// because history reason ,this pass can not do work after constant fold so mark it | |||
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||
if (memcpy_node == nullptr) { | |||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||
if (identity_node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||
return nullptr; | |||
} | |||
return memcpy_node; | |||
return identity_node; | |||
} | |||
/// | |||
@@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra | |||
Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | |||
const OutDataAnchorPtr &src_out_anchor, | |||
const InDataAnchorPtr &hccl_in_anchor) { | |||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(memcpy_node); | |||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(identity_node); | |||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||
kAnchorNum); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
memcpy_node->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -24,13 +24,15 @@ | |||
#include "common/ge/ge_util.h" | |||
#include "framework/common/types.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
namespace { | |||
const int32_t kAnchorSize = 1; | |||
const int kAnchorNum = 0; | |||
const int32_t kAnchorAssignRefIndex = 0; | |||
const int32_t kAnchorAssignValueIndex = 1; | |||
const char *const kInputMutable = "_input_mutable"; | |||
const int32_t kAnchorIdentityIndex = 0; | |||
// attr _input_mutable = true means hccl node will modify its input in runtime | |||
const char *const kModifyInput = "_input_mutable"; | |||
} // namespace | |||
namespace ge { | |||
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
@@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
// need to inset memcpy node between. | |||
// also works on situation that input is variable or const. | |||
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | |||
auto op_desc = node->GetOpDesc(); | |||
bool node_input_mutable = false; | |||
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||
return SUCCESS; | |||
} | |||
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||
if (!node_input_mutable) { | |||
return SUCCESS; | |||
} | |||
GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | |||
if (hccl_in_anchor == nullptr) { | |||
continue; | |||
@@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O | |||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||
node_name = CheckDuplicateName(node_name); | |||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||
if (op_desc == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||
return nullptr; | |||
} | |||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||
return nullptr; | |||
} | |||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||
if (identity_op_desc == nullptr) { | |||
return nullptr; | |||
} | |||
// because history reason ,this pass can not do work after constant fold so mark it | |||
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||
if (memcpy_node == nullptr) { | |||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||
if (identity_node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||
return nullptr; | |||
} | |||
return memcpy_node; | |||
return identity_node; | |||
} | |||
/// | |||
@@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||
/// | |||
Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | |||
const InDataAnchorPtr &hccl_in_anchor) { | |||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(memcpy_node); | |||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(identity_node); | |||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||
kAnchorNum); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
memcpy_node->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
@@ -340,13 +288,13 @@ Status HcclMemcpyPass::InsertAssignAfterBroadcastIfNeed(const ComputeGraphPtr &g | |||
} | |||
ret = assign_out_control_anchor->LinkTo(in_data_anchor->GetOwnerNode()->GetInControlAnchor()); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | |||
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(), assign_out_control_anchor->GetIdx(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
in_data_anchor->GetOwnerNode()->GetType().c_str(), | |||
in_data_anchor->GetIdx()); | |||
REPORT_CALL_ERROR("E19999", "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | |||
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(), | |||
assign_out_control_anchor->GetIdx(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
in_data_anchor->GetOwnerNode()->GetType().c_str(), | |||
in_data_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s link control anchor %s fail.", | |||
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -78,8 +78,6 @@ Status InplaceSupportCheckPass::Run(NodePtr &node) { | |||
AddRePassNode(node); | |||
break; | |||
} | |||
GELOGD("InplaceSupportCheckPass success"); | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -29,7 +29,7 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { | |||
bool forced_unknown = false; | |||
for (const auto &node : graph->GetDirectNode()) { | |||
GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | |||
"Get node[%s] shape status failed!", node->GetName().c_str()); | |||
"Get node[%s] shape status failed!", node->GetName().c_str()); | |||
if (is_unknown_shape) { | |||
break; | |||
} | |||
@@ -1020,7 +1020,7 @@ Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, c | |||
if (!IsGetNextType(data)) { | |||
if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Update input desc shape to op:%s(%s) failed, index:%u", | |||
data->GetName().c_str(), data->GetType().c_str(), kDataInIndex); | |||
data->GetName().c_str(), data->GetType().c_str(), kDataInIndex); | |||
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
@@ -759,7 +759,7 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & | |||
GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); | |||
if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed", | |||
stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str()); | |||
stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str()); | |||
GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
@@ -73,7 +73,7 @@ Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, | |||
if (iter != targets_.end()) { | |||
targets_.erase(iter); | |||
targets_.insert(src_node_ptr); | |||
GELOGI("node [%s] is in user def targets, do not output result to user!", node->GetName().c_str()); | |||
GELOGI("Node [%s] is in user def targets, do not output result to user!", node->GetName().c_str()); | |||
} | |||
is_include_special_node_ = true; | |||
return SUCCESS; | |||
@@ -105,7 +105,7 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vecto | |||
for (auto &ele : graph->GetGraphOutNodesInfo()) { | |||
auto iter = targets_.find(ele.first); | |||
if (iter != targets_.end()) { | |||
GELOGI("user set out node [%s] is found in user def targets, out node is prio!", ele.first->GetName().c_str()); | |||
GELOGI("User set out node [%s] is found in user def targets, out node is prior!", ele.first->GetName().c_str()); | |||
targets_.erase(iter); | |||
} | |||
@@ -213,7 +213,7 @@ Status NetOutputPass::UpdateNetOutputDesc(const ge::NodePtr &net_output) { | |||
std::vector<bool> is_input_const; | |||
for (const auto &in_anchor : net_output->GetAllInDataAnchors()) { | |||
GE_CHECK_NOTNULL(in_anchor); | |||
uint32_t index = static_cast<uint32_t>(in_anchor->GetIdx()); | |||
auto index = static_cast<uint32_t>(in_anchor->GetIdx()); | |||
if (index >= net_output_desc->GetAllInputsDesc().size()) { | |||
REPORT_INNER_ERROR("E19999", "Node:%s(%s) has in_anchor index:%u >= its input desc num:%zu, check invalid", | |||
net_output_desc->GetName().c_str(), net_output_desc->GetType().c_str(), index, | |||
@@ -369,10 +369,9 @@ Status NetOutputPass::UnLinkDataAnchorOfNetoutput(const ge::ComputeGraphPtr &gra | |||
if (!CheckNodeIsInOutputNodes(graph, node)) { | |||
ret = in_data_anchor->Unlink(peer_out_anchor); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
net_out_node->GetName().c_str(), net_out_node->GetType().c_str(), in_data_anchor->GetIdx(), | |||
node->GetName().c_str(), node->GetType().c_str(), peer_out_anchor->GetIdx()); | |||
REPORT_CALL_ERROR("E19999", "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
net_out_node->GetName().c_str(), net_out_node->GetType().c_str(), in_data_anchor->GetIdx(), | |||
node->GetName().c_str(), node->GetType().c_str(), peer_out_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "Unlink peer_out_anchor fail!"); | |||
return ret; | |||
} | |||
@@ -565,7 +564,7 @@ Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, | |||
GELOGI("[NETOUTPUT PASS] Add net output node succeed"); | |||
return SUCCESS; | |||
} | |||
GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size()); | |||
GELOGI("[NETOUTPUT PASS] Output node size:%zu.", output_nodes_info.size()); | |||
if (output_nodes_info.empty()) { | |||
// because retval node is contained by output_nodes_info, here means targets is non-empty | |||
output_node = graph->AddNode(net_output_desc); | |||
@@ -354,7 +354,7 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & | |||
merge_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { | |||
if (SetNextIteration(merge_node, next_node) != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", | |||
next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); | |||
@@ -170,7 +170,7 @@ Status PassUtils::SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, cons | |||
// restore control inputs to dynamically added constant ops, if any | |||
for (const auto &src_out_control_anchor : src_out_control_anchors) { | |||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(src_out_control_anchor, dynamic_const_node->GetInControlAnchor()), | |||
"add edge failed"); | |||
"add edge failed"); | |||
} | |||
} | |||
@@ -51,7 +51,7 @@ std::vector<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesByIndex(const No | |||
auto out_anchor = node->GetOutDataAnchor(index); | |||
if (out_anchor == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no index:%d out data anchor, check invalid", | |||
node->GetName().c_str(), node->GetType().c_str(), index); | |||
node->GetName().c_str(), node->GetType().c_str(), index); | |||
GELOGE(PARAM_INVALID, "Failed to get out data nodes of index %d from node %s, the anchor does not exists", index, | |||
node->GetName().c_str()); | |||
return {}; | |||
@@ -1077,9 +1077,9 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | |||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
@@ -1103,9 +1103,9 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | |||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||
new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetType().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
@@ -87,10 +87,10 @@ Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) { | |||
ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", | |||
prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(), | |||
prev_trans_node_out_anchor->GetIdx(), | |||
ref_node->GetName().c_str(), ref_node->GetType().c_str()); | |||
prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(), | |||
prev_trans_node_out_anchor->GetIdx(), | |||
ref_node->GetName().c_str(), ref_node->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, | |||
"Failed to add edge between ref node %s " | |||
"and the prev node of trans node %s", | |||
@@ -241,7 +241,7 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c | |||
ret = op_desc->AddInputDesc(shape_desc->GetOutputDesc(0)); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add input desc into op:%s(%s) failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Failed to add the first input for reshape %s", name.c_str()); | |||
return nullptr; | |||
} | |||
@@ -837,7 +837,7 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No | |||
old_shape = switchn_output->GetShape(); | |||
if (ModifyFormatAndShapeForSingleTensor(switchn_output) != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Modify format and shape of output:%u in op:%s(%s) failed", i, | |||
switchn_op_desc->GetName().c_str(), switchn_op_desc->GetType().c_str()); | |||
switchn_op_desc->GetName().c_str(), switchn_op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "modify format and shape failed"); | |||
return FAILED; | |||
} | |||
@@ -1266,8 +1266,8 @@ Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index | |||
auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i)); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed", | |||
src_node->GetName().c_str(), src_node->GetType().c_str(), out_index, | |||
merge->GetName().c_str(), merge->GetType().c_str(), i); | |||
src_node->GetName().c_str(), src_node->GetType().c_str(), out_index, | |||
merge->GetName().c_str(), merge->GetType().c_str(), i); | |||
GELOGE(INTERNAL_ERROR, | |||
"Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u", | |||
copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret); | |||
@@ -306,28 +306,15 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
return task_context_; | |||
} | |||
void NodeState::ResetContext(int group) { | |||
SetGroup(group); | |||
if (loop_count_ == 0) { | |||
++loop_count_; | |||
return; | |||
} | |||
++loop_count_; | |||
if (loop_count_ == UINT64_MAX) { | |||
loop_count_ = 1; | |||
} | |||
void NodeState::ResetContext(uint64_t loop_count) { | |||
loop_count_ = loop_count; | |||
switch_index_ = -1; | |||
subgraph_context_->ResetContext(node_item_->node); | |||
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | |||
} | |||
void NodeState::ResetSchedule() { | |||
std::lock_guard<std::mutex> lk(mu_); | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_); | |||
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||
} | |||
Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | |||
@@ -335,14 +322,14 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||
for (const auto &node : node_item_->data_send_) { | |||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | |||
GE_CHECK_NOTNULL(dst_node_state); | |||
dst_node_state->SetDataSchedule(node_item_, ready); | |||
dst_node_state->SetDataSchedule(*this, ready); | |||
} | |||
// Schedule ctrl output. | |||
for (const auto &node : node_item_->ctrl_send_) { | |||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | |||
GE_CHECK_NOTNULL(dst_node_state); | |||
dst_node_state->SetCtrlSchedule(node_item_, ready); | |||
dst_node_state->SetCtrlSchedule(*this, ready); | |||
} | |||
// Schedule switch group. | |||
@@ -351,7 +338,7 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||
for (const auto &node : node_item_->switch_groups_[switch_index_]) { | |||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | |||
GE_CHECK_NOTNULL(dst_node_state); | |||
dst_node_state->SetCtrlSchedule(node_item_, ready); | |||
dst_node_state->SetCtrlSchedule(*this, ready); | |||
} | |||
} | |||
@@ -359,36 +346,44 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||
} | |||
bool NodeState::IsScheduleReady() const { | |||
GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(), | |||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||
return false; | |||
} | |||
GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
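// For StreamMerge nodes, all regular control inputs plus the control group of the current loop phase
// (Enter group on the first pass, NextIteration group afterwards, see GetMergeCtrl) must be scheduled,
// and a single scheduled data input is enough to run.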
if (node_item_->IsMergeOp()) { | |||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||
return false; | |||
} | |||
return data_scheduled_ > 0; | |||
} | |||
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||
return false; | |||
} | |||
// Exit may feed loop times... | |||
return data_scheduled_ >= node_item_->data_recv_.size(); | |||
} | |||
void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||
void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", | |||
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
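// The scheduling producer carries its own loop_count_; if it is ahead of this node,
// this node has entered a new iteration, so reset its per-iteration context first.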
if (loop_count_ != node_state.loop_count_) { | |||
ResetContext(node_state.loop_count_); | |||
} | |||
++data_scheduled_; | |||
if (node_item_->IsMergeOp()) { | |||
const auto it = node_item_->data_recv_.find(node_item); | |||
const auto it = node_item_->data_recv_.find(node_state.node_item_); | |||
if (it != node_item_->data_recv_.end()) { | |||
merge_index_ = it->second; | |||
(void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); | |||
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second); | |||
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_state.GetName().c_str(), GetName().c_str(), it->second); | |||
} else { | |||
GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str()); | |||
GELOGW("[%s] scheduled, [%s] not followed", node_state.GetName().c_str(), GetName().c_str()); | |||
} | |||
} | |||
@@ -397,12 +392,15 @@ void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<v | |||
} | |||
} | |||
void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||
void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", | |||
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
if (loop_count_ != node_state.loop_count_) { | |||
ResetContext(node_state.loop_count_); | |||
} | |||
++ctrl_scheduled_; | |||
if (IsScheduleReady()) { | |||
@@ -410,6 +408,21 @@ void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<v | |||
} | |||
} | |||
void NodeState::RunLoopNext() { | |||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
++loop_count_; | |||
if (loop_count_ == UINT64_MAX) { | |||
loop_count_ = 1; | |||
} | |||
} | |||
void NodeState::RunLoopExit() { | |||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
loop_count_ = 0; | |||
} | |||
void NodeState::SetScheduleFuture(std::future<Status> &&future) { | |||
schedule_future_ = std::move(future); | |||
} | |||
@@ -112,9 +112,8 @@ struct NodeState { | |||
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | |||
} | |||
void ResetContext(int group); | |||
void ResetSchedule(); | |||
void RunLoopNext(); | |||
void RunLoopExit(); | |||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
@@ -166,8 +165,9 @@ struct NodeState { | |||
private: | |||
bool IsScheduleReady() const; | |||
void SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||
void SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void ResetContext(uint64_t loop_count); | |||
const NodeItem *node_item_ = nullptr; | |||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
@@ -46,6 +46,10 @@ Status SubgraphContext::Init() { | |||
return SUCCESS; | |||
} | |||
void SubgraphContext::SetGroup(int group) { | |||
group_ = group; | |||
} | |||
void SubgraphContext::ResetContext(const NodePtr &node) { | |||
node_done_manager_.Reset(node); | |||
} | |||
@@ -84,7 +88,8 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||
auto &node_state = node_states_[node_item]; | |||
if (node_state == nullptr) { | |||
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||
node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this))); | |||
node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||
node_state->SetGroup(group_); | |||
(void)guard; | |||
} | |||
GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | |||
@@ -34,6 +34,7 @@ class SubgraphContext { | |||
~SubgraphContext(); | |||
Status Init(); | |||
void SetGroup(int group); | |||
void ResetContext(const NodePtr &node); | |||
void Reset(); | |||
NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | |||
@@ -58,6 +59,7 @@ class SubgraphContext { | |||
std::vector<TensorValue> all_outputs_; | |||
NodeDoneManager node_done_manager_; | |||
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | |||
int group_ = -1; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
@@ -242,7 +242,6 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||
auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | |||
GE_CHECK_NOTNULL(node_state); | |||
node_state->ResetContext(group); | |||
auto p_node_state = node_state.get(); | |||
if (node_item.node_type == NETOUTPUT) { | |||
@@ -367,7 +366,6 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { | |||
}; | |||
GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); | |||
node_state->ResetSchedule(); | |||
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); | |||
return SUCCESS; | |||
}); | |||
@@ -539,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() { | |||
Status SubgraphExecutor::ScheduleTasks(int group) { | |||
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | |||
subgraph_context_->SetGroup(group); | |||
auto prepare_future = std::async(std::launch::async, [&]() -> Status { | |||
GetContext().SetSessionId(context_->session_id); | |||
GetContext().SetContextId(context_->context_id); | |||
@@ -21,6 +21,7 @@ | |||
#include "graph/ge_context.h" | |||
#include "graph/build/memory/var_mem_assign_util.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/common/omg_util.h" | |||
#include "graph/load/model_manager/model_utils.h" | |||
#include "graph/load/model_manager/model_manager.h" | |||
#include "graph/manager/graph_var_manager.h" | |||
@@ -43,8 +44,9 @@ const uint64_t kProfilingBpEndLogid = 2U; | |||
const uint64_t kProfilingIterEndLogid = 65535U; | |||
const int kBytes = 8; | |||
const int kDecimal = 10; | |||
const uint8_t kStreamActiveIdx = 0; | |||
const uint8_t kStreamActiveNum = 1; | |||
const uint8_t kLoopEnterIdx = 0; | |||
const uint8_t kLoopIterationIdx = 1; | |||
const uint8_t kLoopMergeSize = 2; | |||
const uint8_t kStreamSwitchIdx = 1; | |||
const uint8_t kStreamSwitchNum = 2; | |||
const uint32_t kStringHeadElems = 2; | |||
@@ -57,6 +59,10 @@ const char *const kProfilingArNode = "ProfilingAllReduceNode"; | |||
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | |||
const char *const kForceInfershape = "_force_infershape_when_running"; | |||
const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; | |||
const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; | |||
const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
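// kExecutionDependentTypes: node types whose cond/pred input must be prepared before execution (see ParseDependentInputNodes).
// kMergeInputSkipTypes: root-node types that should not get control edges restored in MergeInputNodes.
// kStreamActiveTypes: source types allowed to keep a control edge to a StreamActive node in CreateNormalNodeGroup.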
Status SetOutputNameAttr(ComputeGraph &graph) { | |||
vector<string> output_names; | |||
for (const auto &node : graph.GetDirectNode()) { | |||
@@ -389,7 +395,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||
} | |||
// The cond or branch input needs to be prepared before executing IF, CASE or STREAMSWITCH.
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | |||
if (kExecutionDependentTypes.count(node_item.node_type) > 0) { | |||
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | |||
GE_CHECK_NOTNULL(src_node); | |||
auto src_node_item = MutableNodeItem(src_node); | |||
@@ -575,7 +581,7 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { | |||
auto in_nodes = root_node->GetInAllNodes(); | |||
std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); | |||
for (auto &in_control_node : wrapped_node->GetInControlNodes()) { | |||
if (in_node_set.count(in_control_node) == 0) { | |||
if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { | |||
GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); | |||
GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); | |||
(void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); | |||
@@ -2282,8 +2288,6 @@ Status HybridModelBuilder::RelinkNextIteration() { | |||
} | |||
} | |||
stream_merge_op_nodes_.clear(); | |||
next_iteration_op_nodes_.clear(); | |||
return SUCCESS; | |||
} | |||
@@ -2371,10 +2375,12 @@ Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const No | |||
} | |||
Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { | |||
const auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||
GE_CHECK_NOTNULL(dst_node); | |||
if ((dst_node->GetType() == STREAMACTIVE) && (kStreamActiveTypes.count(node->GetType()) == 0)) { | |||
GELOGI("[%s] ignore control to [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); | |||
continue; | |||
} | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||
@@ -2384,27 +2390,80 @@ Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem * | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item) { | |||
// Enter --> StreamActive --> StreamMerge | |||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||
GE_CHECK_NOTNULL(dst_node); | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||
// Set Enter Control to StreamMerge as Group 0. | |||
dst_node_item->switch_groups_.resize(kLoopMergeSize); | |||
dst_node_item->SetMergeCtrl(node_item, kLoopEnterIdx); | |||
} | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item) { | |||
// NextIteration --> StreamActive {-->} StreamMerge | |||
std::string node_name; | |||
for (const auto &src_node : node->GetInControlNodes()) { | |||
GE_CHECK_NOTNULL(src_node); | |||
if (kNextIterationOpTypes.count(src_node->GetType()) == 0) { | |||
GELOGI("[%s] Skip Not NextIteration node [%s]", node->GetName().c_str(), src_node->GetName().c_str()); | |||
continue; | |||
} | |||
if (!AttrUtils::GetStr(src_node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, node_name)) { | |||
GELOGE(INTERNAL_ERROR, "[%s] input node [%s] expect attribute[%s] not found", | |||
node->GetName().c_str(), src_node->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
const auto it = stream_merge_op_nodes_.find(node_name); | |||
if (it == stream_merge_op_nodes_.end()) { | |||
GELOGE(INTERNAL_ERROR, "[%s] expect StreamMerge[%s] not found", node->GetName().c_str(), node_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
const auto &dst_node = it->second; | |||
GE_CHECK_NOTNULL(dst_node); | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", | |||
dst_node->GetName().c_str()); | |||
// Set NextIteration Control to StreamMerge as Group 1. | |||
dst_node_item->SetMergeCtrl(node_item, kLoopIterationIdx); | |||
} | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { | |||
if (node_item->node_type != STREAMACTIVE) { | |||
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
node_item->switch_groups_.resize(kStreamActiveNum); | |||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||
GE_CHECK_NOTNULL(dst_node); | |||
if (dst_node->GetType() == STREAMMERGE) { | |||
GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str()); | |||
continue; | |||
} | |||
const auto ctrl_nodes = node->GetInControlNodes(); | |||
if (ctrl_nodes.empty()) { | |||
GELOGW("Skip no in control node: %s", node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||
node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx); | |||
const auto IsEnterNode = [](const NodePtr &n) { | |||
return kEnterOpTypes.count(n->GetType()) > 0; | |||
}; | |||
const auto IsIterationNode = [](const NodePtr &n) { | |||
return kNextIterationOpTypes.count(n->GetType()) > 0; | |||
}; | |||
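// Dispatch by the type of this StreamActive's in-control nodes: an Enter source means the
// loop-entry merge group, a NextIteration source means the iteration merge group.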
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | |||
// Enter --> StreamActive --> StreamMerge | |||
return CreateMergeEnterGroup(node, node_item); | |||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||
// NextIteration --> StreamActive {-->} StreamMerge | |||
return CreateMergeIterationGroup(node, node_item); | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -2416,11 +2475,8 @@ Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem | |||
// Consider as two groups, group[0] set empty for false, group[1] for true. | |||
node_item->switch_groups_.resize(kStreamSwitchNum); | |||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||
GE_CHECK_NOTNULL(dst_node); | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||
@@ -2447,20 +2503,17 @@ Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeIte | |||
} | |||
node_item->switch_groups_.resize(batch_num); | |||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||
GE_CHECK_NOTNULL(dst_node); | |||
std::string batch_label; | |||
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str()); | |||
if (!AttrUtils::GetStr(dst_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", dst_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
std::string::size_type pos = batch_label.rfind("_"); | |||
if (pos == std::string::npos) { | |||
GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str()); | |||
GELOGW("[%s] Separator not found in batch label: %s.", dst_node->GetName().c_str(), batch_label.c_str()); | |||
continue; | |||
} | |||
@@ -2486,7 +2539,7 @@ Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeIte | |||
return INTERNAL_ERROR; | |||
} | |||
return SUCCESS; | |||
return CreateNormalNodeGroup(node, node_item); | |||
} | |||
Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { | |||
@@ -2495,11 +2548,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node | |||
return INTERNAL_ERROR; | |||
} | |||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||
GE_CHECK_NOTNULL(dst_node); | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||
@@ -2509,11 +2559,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node | |||
// Group switch flow by out put data. | |||
node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); | |||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||
const auto &out_anchor = node->GetOutDataAnchor(i); | |||
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||
for (const auto &dst_node : node->GetOutDataNodes()) { | |||
GE_CHECK_NOTNULL(dst_node); | |||
NodeItem *dst_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||
@@ -99,6 +99,8 @@ class HybridModelBuilder { | |||
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | |||
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | |||
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | |||
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | |||
Status CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item); | |||
Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); | |||
Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); | |||
Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); | |||
@@ -34,8 +34,8 @@ const std::set<std::string> kControlOpTypes{ | |||
}; | |||
const std::set<std::string> kControlFlowOpTypes{ | |||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX, | |||
NEXTITERATION, REFNEXTITERATION | |||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | |||
}; | |||
const std::set<std::string> kMergeOpTypes{ | |||
@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||
if (is_root_node_) { | |||
node_item->root_data_.emplace(this); | |||
} | |||
// If Enter feeds a non-Merge node, take it as a root node input.
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||
node_item->root_data_.emplace(this); | |||
node_item->enter_inside_.emplace(anchor_index); | |||
} | |||
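// enter_inside_ marks these inputs so TaskContext::ReleaseInput keeps the Enter-fed tensor alive
// across loop iterations instead of destroying it after the first use.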
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
} | |||
@@ -416,10 +421,31 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||
if (is_root_node_) { | |||
node_item->root_ctrl_.emplace(this); | |||
} | |||
// If Enter feeds a control signal, take it as a root node input.
if (kEnterOpTypes.count(node_type) > 0) { | |||
node_item->root_ctrl_.emplace(this); | |||
} | |||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
} | |||
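// SetMergeCtrl is called on a StreamMerge node with the StreamActive that activates it:
// group kLoopEnterIdx(0) collects Enter-driven activations, group kLoopIterationIdx(1) collects
// NextIteration-driven ones; IsScheduleReady selects the group by loop_count_.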
void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { | |||
if (merge_index >= switch_groups_.size()) { | |||
GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index); | |||
return; | |||
} | |||
// This node is the StreamMerge node; node_item is the StreamActive node.
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index]; | |||
switch_group.emplace_back(node_item); | |||
node_item->ctrl_send_.emplace(this); | |||
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | |||
} | |||
size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const { | |||
return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0; | |||
} | |||
OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { | |||
if (mu_ != nullptr) { | |||
GELOGD("lock for %s", name_.c_str()); | |||
@@ -98,6 +98,8 @@ struct NodeItem { | |||
void SetDataSend(NodeItem *node_item, int anchor_index); | |||
void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); | |||
void SetMergeCtrl(NodeItem *node_item, uint32_t merge_index); | |||
size_t GetMergeCtrl(uint32_t merge_index) const; | |||
OptionalMutexGuard MutexGuard(const std::string &name) const { | |||
return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); | |||
@@ -140,6 +142,7 @@ struct NodeItem { | |||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | |||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||
std::set<int> enter_inside_; // Input indices fed by Enter inside the loop body, not crossing a Merge.
std::shared_ptr<NodeTask> kernel_task; | |||
std::unique_ptr<FusedSubgraph> fused_subgraph; | |||
@@ -420,9 +420,8 @@ Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) | |||
} | |||
Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { | |||
size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + | |||
task_context.NumWorkspaces() | |||
- output_indices_to_skip_.size(); | |||
size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + task_context.NumWorkspaces() - | |||
output_indices_to_skip_.size(); | |||
if (tiling_buffer_ != nullptr) { | |||
++expected_arg_count; | |||
} | |||
@@ -37,7 +37,7 @@ const std::map<std::string, std::vector<uint32_t>> | |||
{BROADCASTGRADIENTARGS, {}} | |||
}; | |||
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP}; | |||
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; | |||
Status RefInputTask::UpdateArgs(TaskContext &) { | |||
// no need update args | |||
@@ -252,9 +252,16 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, | |||
GELOGE(INTERNAL_ERROR, "[Get][Tensor] failed for name: %s", node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
task = MakeShared<ConstantNodeTask>(tensor); | |||
GE_CHECK_NOTNULL(task); | |||
} else if (node_type == NOOP) { | |||
GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str()); | |||
task = MakeShared<NoOpNodeTask>(); | |||
if (task == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Create NoOpNodeTask failed for NoOp node %s.", node->GetName().c_str()); | |||
GELOGE(MEMALLOC_FAILED, "[Create][NoOpNodeTask]failed for NoOp node %s.", node->GetName().c_str()); | |||
return MEMALLOC_FAILED; | |||
} | |||
} else { | |||
GELOGE(UNSUPPORTED, "node %s type %s is not support in GeLocalNodeExecutor now.", | |||
node->GetName().c_str(), node_type.c_str()); | |||
@@ -280,5 +287,17 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function<void() | |||
GELOGD("[%s] Done execute successfully.", context.GetNodeName()); | |||
return SUCCESS; | |||
} | |||
Status NoOpNodeTask::UpdateArgs(TaskContext &context) { | |||
// no need to update args | |||
return SUCCESS; | |||
} | |||
Status NoOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | |||
GELOGD("[%s] Start execute.", context.GetNodeName()); | |||
GE_CHK_STATUS_RET(context.TryExecuteCallback(done_callback)); | |||
GELOGD("[%s] Done execute successfully.", context.GetNodeName()); | |||
return SUCCESS; | |||
} | |||
} // namespace hybrid | |||
} // namespace ge |
@@ -80,6 +80,14 @@ class ConstantNodeTask : public NodeTask { | |||
const TensorValue *tensor_; | |||
}; | |||
class NoOpNodeTask : public NodeTask { | |||
public: | |||
explicit NoOpNodeTask() = default; | |||
~NoOpNodeTask() = default; | |||
Status UpdateArgs(TaskContext &context) override; | |||
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | |||
}; | |||
class GeLocalNodeExecutor : public NodeExecutor { | |||
public: | |||
@@ -20,6 +20,7 @@ | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "common/ge/ge_util.h" | |||
#include "common/op/ge_op_utils.h" | |||
@@ -201,6 +202,13 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||
GE_CHECK_NOTNULL(in_x); | |||
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y | |||
const auto &node_state = task_context.GetNodeState(); | |||
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | |||
node_state->RunLoopNext(); | |||
} else if (kExitOpTypes.count(node_state->GetType()) > 0) { | |||
node_state->RunLoopExit(); | |||
} | |||
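// NextIteration advances the node's loop counter (RunLoopNext) while Exit resets it to 0
// (RunLoopExit); downstream scheduling compares loop_count_ to decide when to reset context.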
if (done_callback) { | |||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||
} | |||
@@ -61,6 +61,6 @@ class RtsTaskFactory { | |||
REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | |||
#define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | |||
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> RtsNodeTaskPtr { return MakeShared<clazz>(); }) | |||
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()->RtsNodeTaskPtr { return MakeShared<clazz>(); }) | |||
#endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ |
@@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||
} | |||
void TaskContext::ReleaseInput(int index) { | |||
if (node_item_->enter_inside_.count(index) > 0) { | |||
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||
return; | |||
} | |||
auto input_tensor = MutableInput(index); | |||
if (input_tensor != nullptr) { | |||
input_tensor->Destroy(); | |||
@@ -37,6 +37,9 @@ const size_t kMaxNDDimNum = 4; | |||
const size_t kMinNDDimNum = 1; | |||
const size_t kSquareBracketsSize = 2; | |||
const size_t kRangePairSize = 2; | |||
const size_t kShapeRangeSize = 2; | |||
const size_t kShapeRangeStrIndex = 2; | |||
const size_t kShapeRangeStrSize = 3; | |||
// datatype/formats from user to GE, Unified to util interface file later | |||
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = { | |||
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; | |||
@@ -434,7 +437,7 @@ Status ParseInputShapeRange(const std::string &shape_range, | |||
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | |||
GELOGD("Input shape range %s", shape_range.c_str()); | |||
if (shape_range.size() < 2) { | |||
if (shape_range.size() < kShapeRangeSize) { | |||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||
std::vector<std::string>({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4})); | |||
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | |||
@@ -451,7 +454,7 @@ Status ParseInputShapeRange(const std::string &shape_range, | |||
return PARAM_INVALID; | |||
} | |||
for (auto &shape_range_str : shape_range_set) { | |||
if (shape_range_str.size() < 3) { | |||
if (shape_range_str.size() < kShapeRangeStrSize) { | |||
// shape_range_str should be "[2~3,1" | |||
// or ",[2~3,1". because we should trim '[' or ',[' | |||
// so shape_range_str.size() < 3 is invalid | |||
@@ -462,7 +465,7 @@ Status ParseInputShapeRange(const std::string &shape_range, | |||
shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | |||
} | |||
if (ge::StringUtils::StartWith(shape_range_str, ",")) { | |||
shape_range_str = shape_range_str.substr(2, shape_range_str.size()); | |||
shape_range_str = shape_range_str.substr(kShapeRangeStrIndex, shape_range_str.size()); | |||
} | |||
// parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||
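// Illustrative full input (based on the samples above): "[1~20,3,3~6,-1],[2~3,1]" describes two inputs;
// each dimension becomes a {min, max} pair in `range`, with fixed dims like 3 presumably stored as {3, 3}.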
@@ -0,0 +1,25 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
export PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../} | |||
function main(){ | |||
${PROJECT_HOME}/build.sh "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -1,5 +1,5 @@ | |||
#!/bin/bash | |||
# Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -13,7 +13,6 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1) | |||
@@ -25,10 +24,10 @@ if [[ "${version}" -lt "8" ]]; then | |||
fi | |||
CURRENT_PATH=$(pwd) | |||
SCRIPTS_PATH=$(dirname "$0") | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../} | |||
echo "CURRENT_PATH=$CURRENT_PATH" | |||
echo "SCRIPTS_PATH=$SCRIPTS_PATH" | |||
echo "PROJECT_HOME=$PROJECT_HOME" | |||
# print usage message | |||
function usage() | |||
@@ -81,45 +80,46 @@ function checkopts() | |||
checkopts "$@" | |||
# switch to project root path, which contains clang-format config file '.clang-format' | |||
cd "${SCRIPTS_PATH}/.." || exit 1 | |||
pushd "${CURRENT_PATH}" | |||
CHECK_LIST_FILE='__checked_files_list__' | |||
cd "${PROJECT_HOME}" || exit 1 | |||
CHECK_LIST_FILE='__checked_files_list__' | |||
if [ "X${mode}" == "Xall" ]; then | |||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${CHECK_LIST_FILE}" || true | |||
elif [ "X${mode}" == "Xchanged" ]; then | |||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||
else # "X${mode}" == "Xlastcommit" | |||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||
fi | |||
if [ "X${mode}" == "Xall" ]; then | |||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${CHECK_LIST_FILE}" || true | |||
elif [ "X${mode}" == "Xchanged" ]; then | |||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||
else # "X${mode}" == "Xlastcommit" | |||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true | |||
fi | |||
CHECK_RESULT_FILE=__code_format_check_result__ | |||
echo "0" > "$CHECK_RESULT_FILE" | |||
CHECK_RESULT_FILE=__code_format_check_result__ | |||
echo "0" > "$CHECK_RESULT_FILE" | |||
# check format of files modified in the latest commit
while read line; do | |||
BASE_NAME=$(basename "${line}") | |||
TEMP_FILE="__TEMP__${BASE_NAME}" | |||
cp "${line}" "${TEMP_FILE}" | |||
${CLANG_FORMAT} -i "${TEMP_FILE}" | |||
set +e | |||
diff "${TEMP_FILE}" "${line}" | |||
ret=$? | |||
set -e | |||
rm "${TEMP_FILE}" | |||
if [[ "${ret}" -ne 0 ]]; then | |||
echo "File ${line} is not formated, please format it." | |||
echo "1" > "${CHECK_RESULT_FILE}" | |||
break | |||
fi | |||
done < "${CHECK_LIST_FILE}" | |||
# check format of files modified in the latest commit
while read line; do | |||
BASE_NAME=$(basename "${line}") | |||
TEMP_FILE="__TEMP__${BASE_NAME}" | |||
cp "${line}" "${TEMP_FILE}" | |||
${CLANG_FORMAT} -i "${TEMP_FILE}" | |||
set +e | |||
diff "${TEMP_FILE}" "${line}" | |||
ret=$? | |||
set -e | |||
rm "${TEMP_FILE}" | |||
if [[ "${ret}" -ne 0 ]]; then | |||
echo "File ${line} is not formated, please format it." | |||
echo "1" > "${CHECK_RESULT_FILE}" | |||
break | |||
fi | |||
done < "${CHECK_LIST_FILE}" | |||
result=$(cat "${CHECK_RESULT_FILE}") | |||
rm "${CHECK_RESULT_FILE}" | |||
rm "${CHECK_LIST_FILE}" | |||
popd | |||
result=$(cat "${CHECK_RESULT_FILE}") | |||
rm "${CHECK_RESULT_FILE}" | |||
rm "${CHECK_LIST_FILE}" | |||
cd "${CURRENT_PATH}" || exit 1 | |||
if [[ "X${result}" == "X0" ]]; then | |||
echo "Check PASS: specified files are well formated!" | |||
fi | |||
@@ -0,0 +1,90 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge clean [OPTIONS] | |||
Options: | |||
-b, --build Clean build dir | |||
-d, --docs Clean generate docs | |||
-i, --install Clean dependencies
-a, --all Clean all | |||
-h, --help | |||
EOF | |||
} | |||
function clean_relative_dir(){ | |||
rm -rf "${PROJECT_HOME}/${1:-output}" | |||
} | |||
function parse_args(){ | |||
parsed_args=$(getopt -a -o bdiah --long build,docs,install,all,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
clean_relative_dir "build" | |||
clean_relative_dir "output" | |||
exit 1 | |||
fi | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-b | --build) | |||
clean_relative_dir "build" | |||
clean_relative_dir "output" | |||
;; | |||
-d | --docs) | |||
clean_relative_dir "docs/doxygen" | |||
;; | |||
-i | --install) | |||
clean_relative_dir "deps" | |||
;; | |||
-a | --all) | |||
clean_relative_dir "deps" | |||
clean_relative_dir "build" | |||
clean_relative_dir "output" | |||
clean_relative_dir "docs/doxygen" | |||
;; | |||
-h | --help) | |||
help | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1 | |||
;; | |||
esac | |||
shift | |||
done | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -0,0 +1,115 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge config [OPTIONS] | |||
update server config for ge, you need to input all config info (ip, user, password) | |||
Options: | |||
-i, --ip Config ip config | |||
-u, --user Config user name | |||
-p, --password Config password | |||
-h, --help | |||
Example: ge config -i=121.36.**.** -u=asc** -p=Asc***\#@\!\$ (You need to add the escape character \ before the special characters $, #, !) | |||
EOF | |||
} | |||
function write_config_file(){ | |||
local IP=$1 | |||
local USER=$2 | |||
local PASSWORD=$3 | |||
if [[ -z "$IP" ]] || [[ -z "$USER" ]] || [[ -z "$USER" ]]; then | |||
echo "You need input all info (ip, user,password)obout server config !!!" | |||
help | |||
exit 1 | |||
fi | |||
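  # Escape the shell-special characters !, # and $ so the password survives being written to and later sourced from the generated config file | |||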
local PASSWORD=${PASSWORD//!/\\!} | |||
local PASSWORD=${PASSWORD//#/\\#} | |||
  local PASSWORD=${PASSWORD//\$/\\\$} | |||
local SERVER_CONFIG_FILE=${PROJECT_HOME}/scripts/config/server_config.sh | |||
[ -n "${SERVER_CONFIG_FILE}" ] && rm -rf "${SERVER_CONFIG_FILE}" | |||
cat>${SERVER_CONFIG_FILE}<<-EOF | |||
SERVER_PATH=http://${IP}/package/etrans | |||
DEP_USER=${USER} | |||
DEP_PASSWORD=${PASSWORD} | |||
EOF | |||
} | |||
function parse_args(){ | |||
parsed_args=$(getopt -a -o i::u::p::h --long ip::,user::,password::,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
help | |||
exit 1 | |||
fi | |||
local IP= | |||
local USER= | |||
local PASSWORD= | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-i | --ip) | |||
IP=$2 | |||
;; | |||
-u | --user) | |||
USER=$2 | |||
;; | |||
-p | --password) | |||
PASSWORD=$2 | |||
;; | |||
-h | --help) | |||
help; exit; | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1 | |||
;; | |||
esac | |||
shift 2 | |||
done | |||
  write_config_file "$IP" "$USER" "$PASSWORD" | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -0,0 +1,136 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge cov [OPTIONS] | |||
Options: | |||
-a, --all Full coverage | |||
-i, --increment Increment coverage | |||
-d, --directory Coverage of directory | |||
-h, --help | |||
EOF | |||
} | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
ALL_COV_GEN_PATH=${PROJECT_HOME}/cov/all | |||
DIFF_FILE_PATH=${PROJECT_HOME}/cov/diff | |||
DIFF_FILE_NAME=${DIFF_FILE_PATH}/inc_change_diff.txt | |||
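# Rewrite the diff header paths (--- a/..., +++ b/...) to the in-container source root so addlcov can match them against coverage.info | |||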
function process_diff_format(){ | |||
sed -i "s/--- a/--- \/code\/Turing\/graphEngine/g" ${DIFF_FILE_NAME} | |||
sed -i "s/+++ b/+++ \/code\/Turing\/graphEngine/g" ${DIFF_FILE_NAME} | |||
} | |||
function add_cov_generate(){ | |||
addlcov --diff ${ALL_COV_GEN_PATH}/coverage.info ${DIFF_FILE_NAME} -o ${PROJECT_HOME}/cov/diff/inc_coverage.info | |||
} | |||
function gen_add_cov_html(){ | |||
genhtml --prefix ${PROJECT_HOME} -o ${PROJECT_HOME}/cov/diff/html ${PROJECT_HOME}/cov/diff/inc_coverage.info --legend -t CHG --no-branch-coverage --no-function-coverage | |||
} | |||
function increment_cov_for_directory(){ | |||
[ -n "${DIFF_FILE_PATH}" ] && rm -rf "${DIFF_FILE_PATH}" | |||
mkdir -p ${DIFF_FILE_PATH} | |||
git diff HEAD -- $1 >>${DIFF_FILE_NAME} | |||
process_diff_format | |||
add_cov_generate | |||
gen_add_cov_html | |||
} | |||
function run_all_coverage(){ | |||
[ -n "${ALL_COV_GEN_PATH}" ] && rm -rf ${ALL_COV_GEN_PATH} | |||
mkdir -p ${ALL_COV_GEN_PATH} | |||
pushd "${PWD}" >/dev/null | |||
cd ${PROJECT_HOME} | |||
lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o ${ALL_COV_GEN_PATH}/tmp.info | |||
lcov -r ${ALL_COV_GEN_PATH}/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' '/usr/include/*' '*/metadef/*' '*/parser/*' -o ${ALL_COV_GEN_PATH}/coverage.info | |||
cd ${ALL_COV_GEN_PATH} | |||
genhtml coverage.info | |||
popd >/dev/null | |||
} | |||
function do_coverage_run(){ | |||
local cov_mode=$1 | |||
local directory_dir=$2 | |||
run_all_coverage | |||
if [ "$cov_mode" = "all" ]; then | |||
exit 1 | |||
elif [ -n "$directory_dir" ]; then | |||
increment_cov_for_directory $directory_dir | |||
else | |||
increment_cov_for_directory "ge" | |||
fi | |||
} | |||
function parse_args(){ | |||
parsed_args=$(getopt -a -o aid::h --long all,increment,directory::,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
run_all_coverage | |||
exit 1 | |||
fi | |||
local cov_mode="increment" | |||
local directory_dir= | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-a | --all) | |||
cov_mode="all" | |||
;; | |||
-i | --increment) | |||
;; | |||
-d | --directory) | |||
directory_dir=$2 | |||
shift | |||
;; | |||
-h | --help) | |||
help; exit 1; | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1; | |||
;; | |||
esac | |||
shift | |||
done | |||
do_coverage_run $cov_mode $directory_dir | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -0,0 +1,87 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge docs [OPTIONS] | |||
Options: | |||
-b, --brief Build brief docs | |||
-a, --all Build all docs | |||
-h, --help | |||
EOF | |||
} | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
BRIEF_DOXYFILE_PATH=${PROJECT_HOME}/scripts/docs/Doxyfile_brief | |||
ALL_DOXYFILE_PATH=${PROJECT_HOME}/scripts/docs/Doxyfile_all | |||
function build_brief_docs(){ | |||
rm -rf "${PROJECT_HOME}/docs/doxygen" | |||
doxygen ${BRIEF_DOXYFILE_PATH} | |||
} | |||
function build_all_docs(){ | |||
rm -rf "${PROJECT_HOME}/docs/doxygen" | |||
doxygen ${ALL_DOXYFILE_PATH} | |||
} | |||
function parse_args(){ | |||
parsed_args=$(getopt -a -o bah --long brief,all,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
build_all_docs | |||
exit 1 | |||
fi | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-b | --brief) | |||
build_brief_docs | |||
;; | |||
-a | --all) | |||
build_all_docs | |||
;; | |||
-h | --help) | |||
help; exit 1; | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1; | |||
;; | |||
esac | |||
shift | |||
done | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -0,0 +1,42 @@ | |||
# This Dockerfile is used for the graphengine build | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
FROM ubuntu:18.04 | |||
RUN apt-get update \ | |||
&& apt-get install -y git g++ wget unzip clang-format-9 build-essential lcov vim | |||
# install for doxygen | |||
RUN apt-get install -y graphviz doxygen | |||
# install Graph::Easy for graph rendering | |||
RUN cpan install -y Graph::Easy | |||
RUN wget https://cmake.org/files/v3.16/cmake-3.16.7-Linux-x86_64.tar.gz | |||
RUN mkdir -p /opt/cmake-3.16.7 \ | |||
&& tar -xvf cmake-3.16.7-Linux-x86_64.tar.gz -C /opt/cmake-3.16.7 --strip-components=1 \ | |||
&& ln -sf /opt/cmake-3.16.7/bin/* /usr/bin/ \ | |||
&& mv /usr/bin/clang-format-9 /usr/bin/clang-format | |||
RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_lcov.tar.gz \ | |||
&& mkdir -p /opt/addlcov1.0.0 \ | |||
&& tar -xvf add_lcov.tar.gz -C /opt/addlcov1.0.0 \ | |||
&& mv /opt/addlcov1.0.0/lcov-add_lcov/bin/lcov /usr/bin/addlcov | |||
ENV PROJECT_HOME=/code/Turing/graphEngine | |||
RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | |||
@@ -0,0 +1,146 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | |||
DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | |||
DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||
DOCKER_IMAGE_NAME=joycode2art/turing | |||
DOCKER_FULL_IMAGE_NAME=${DOCKER_IMAGE_NAME}:${DOCKER_IMAGE_TAG} | |||
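# Pick the docker command and the path style according to the host OS (macOS, Windows under MinGW, or native Linux) | |||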
if [ "$(uname)" == "Darwin" ]; then | |||
#running on Mac OS | |||
docker_cmd=docker | |||
MOUNT_PROJECT_HOME=${MOUNT_PROJECT_HOME} | |||
docker_work_dir=/code/Turing/graphEngine | |||
docker_bash_dir=/bin/bash | |||
elif [ "$(expr substr "$(uname -s)" 1 10)" == "MINGW32_NT" ] || [ "$(expr substr "$(uname -s)" 1 10)" == "MINGW64_NT" ]; then | |||
#running on Windows | |||
docker_cmd="winpty docker" | |||
MOUNT_PROJECT_HOME=/${MOUNT_PROJECT_HOME} | |||
docker_work_dir=//code/Turing/graphEngine | |||
docker_bash_dir=//bin/bash | |||
elif [ "$(expr substr "$(uname -s)" 1 5)" == "Linux" ]; then | |||
#running on Linux | |||
docker_cmd=docker | |||
MOUNT_PROJECT_HOME=${PROJECT_HOME} | |||
docker_work_dir=/code/Turing/graphEngine | |||
docker_bash_dir=/bin/bash | |||
fi | |||
function build_docker_image(){ | |||
    if test -z "$(docker images |grep ${DOCKER_IMAGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | |||
$docker_cmd build -t ${DOCKER_FULL_IMAGE_NAME} ${PROJECT_HOME}/scripts/env | |||
else | |||
echo "docker image for graph engine build is build ok...." | |||
fi | |||
} | |||
function pull_docker_image(){ | |||
$docker_cmd pull $DOCKER_FULL_IMAGE_NAME | |||
} | |||
function enter_docker_env(){ | |||
    if test -z "$(docker images |grep ${DOCKER_IMAGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | |||
echo "please run 'ge env --pull' to download images first!" | |||
elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
$docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
$docker_cmd start ${DOCKER_BUILD_ENV_NAME} | |||
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | |||
else | |||
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | |||
fi | |||
} | |||
function reset_docker_env(){ | |||
    if test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
        echo "no running container for graphengine build" | |||
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
$docker_cmd rm ${DOCKER_BUILD_ENV_NAME} | |||
else | |||
$docker_cmd stop ${DOCKER_BUILD_ENV_NAME} | |||
$docker_cmd rm ${DOCKER_BUILD_ENV_NAME} | |||
fi | |||
} | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge env [OPTIONS] | |||
Prepare for docker env for build and test | |||
Options: | |||
-b, --build Build docker image | |||
-p, --pull Pull docker image | |||
-e, --enter Enter container | |||
-r, --reset Reset container | |||
-h, --help | |||
EOF | |||
} | |||
function parse_args(){ | |||
  parsed_args=$(getopt -a -o bperh --long build,pull,enter,reset,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
pull_docker_image | |||
enter_docker_env | |||
exit 1 | |||
fi | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-b | --build) | |||
build_docker_image | |||
;; | |||
-p | --pull) | |||
pull_docker_image | |||
;; | |||
-e | --enter) | |||
enter_docker_env | |||
;; | |||
-r | --reset) | |||
      reset_docker_env | |||
;; | |||
-h | --help) | |||
help | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1 | |||
;; | |||
esac | |||
shift | |||
done | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -1,5 +1,5 @@ | |||
#!/bin/bash | |||
# Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -24,11 +24,12 @@ if [[ "${version}" -lt "8" ]]; then | |||
exit 1 | |||
fi | |||
CURRENT_PATH=$(pwd) | |||
SCRIPTS_PATH=$(dirname "$0") | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
echo "CURRENT_PATH=${CURRENT_PATH}" | |||
echo "SCRIPTS_PATH=${SCRIPTS_PATH}" | |||
echo "PROJECT_HOME=${PROJECT_HOME}" | |||
# print usage message | |||
function usage() | |||
@@ -81,27 +82,28 @@ function checkopts() | |||
checkopts "$@" | |||
# switch to project root path, which contains clang-format config file '.clang-format' | |||
cd "${SCRIPTS_PATH}/.." || exit 1 | |||
FMT_FILE_LIST='__format_files_list__' | |||
if [[ "X${mode}" == "Xall" ]]; then | |||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||
elif [[ "X${mode}" == "Xchanged" ]]; then | |||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||
else # "X${mode}" == "Xlastcommit" | |||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||
fi | |||
while read line; do | |||
if [ -f "${line}" ]; then | |||
${CLANG_FORMAT} -i "${line}" | |||
fi | |||
done < "${FMT_FILE_LIST}" | |||
pushd "${CURRENT_PATH}" | |||
cd "${PROJECT_HOME}" || exit 1 | |||
FMT_FILE_LIST='__format_files_list__' | |||
if [[ "X${mode}" == "Xall" ]]; then | |||
find src -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||
find inc -type f -name "*" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||
elif [[ "X${mode}" == "Xchanged" ]]; then | |||
# --diff-filter=ACMRTUXB will ignore deleted files in commit | |||
git diff --diff-filter=ACMRTUXB --name-only | grep "^inc\|^src" | grep "\.h$\|\.cc$" >> "${FMT_FILE_LIST}" || true | |||
else # "X${mode}" == "Xlastcommit" | |||
git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "^inc\|^src" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true | |||
fi | |||
while read line; do | |||
if [ -f "${line}" ]; then | |||
${CLANG_FORMAT} -i "${line}" | |||
fi | |||
done < "${FMT_FILE_LIST}" | |||
rm "${FMT_FILE_LIST}" | |||
cd "${CURRENT_PATH}" || exit 1 | |||
rm "${FMT_FILE_LIST}" | |||
popd | |||
echo "Specified cpp source files have been format successfully." |
@@ -0,0 +1,77 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
GE_BASH_HOME=$(dirname "$0") | |||
export PROJECT_HOME=${PROJECT_HOME:-${GE_BASH_HOME}/../} | |||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge COMMANDS | |||
Run ge commands | |||
Commands: | |||
env Prepare docker env | |||
config Config dependencies server | |||
update Update dependencies | |||
     format Format code | |||
     lint Static verify | |||
build Build code | |||
test Run test of UT/ST | |||
cov Run Coverage | |||
docs Generate documents | |||
clean Clean | |||
EOF | |||
} | |||
function ge_error() { | |||
echo "Error: $*" >&2 | |||
help | |||
exit 1 | |||
} | |||
function main(){ | |||
if [ $# -eq 0 ]; then | |||
help; exit 0 | |||
fi | |||
local cmd=$1 | |||
local shell_cmd= | |||
shift | |||
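  # Dispatch: 'build' maps to the top-level build.sh, any other command maps to scripts/<cmd>/ge_<cmd>.sh | |||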
case "$cmd" in | |||
-h|--help) | |||
help; exit 0 | |||
;; | |||
build) | |||
shell_cmd=${PROJECT_HOME}/build.sh | |||
;; | |||
*) | |||
shell_cmd=$GE_BASH_HOME/$cmd/ge_$cmd.sh | |||
;; | |||
esac | |||
  [ -e "$shell_cmd" ] || { | |||
    ge_error "ge command script ${shell_cmd} not found" | |||
} | |||
$shell_cmd "$@" | |||
} | |||
main "$@" | |||
@@ -0,0 +1,331 @@ | |||
# graph engine Developer Toolchain User Guide | |||
The GE developer toolchain is a set of automation scripts in graph engine aimed at individual developers. | |||
It currently covers the common developer workflows: container-based environment setup, automatic download, installation and configuration of build dependencies, code formatting, building, testing, code coverage checking, and document generation. | |||
## Prerequisites | |||
The following manual preparation is needed before using the GE developer toolchain. | |||
The verified configuration with the best performance is: | |||
1. Operating system, any one of the following: | |||
   - a native Linux OS, such as Ubuntu; | |||
   - Windows, preferably with Ubuntu on WSL; it is strongly recommended to clone the code inside WSL directly instead of mounting a Windows volume (build performance is poor otherwise)! | |||
   - macOS; | |||
2. Docker installation: | |||
   - Docker is installed, the image registries are configured correctly, and external images can be pulled normally. | |||
3. A command-line shell supported by the OS: a native Linux or WSL shell. | |||
A usable but not recommended configuration: | |||
- installing Docker directly on Windows and running the ge toolchain from a Linux-like bash (Cygwin, MinGW, etc.); | |||
  (all GE toolchain operations still work this way, but builds are slow because of the file-access limitation between Windows and the container's foreign kernel) | |||
## Quick Start | |||
The GE toolchain scripts live under scripts and can be run as follows: | |||
1. Enter the scripts directory: | |||
```sh | |||
$ cd ./scripts | |||
``` | |||
2. Run `ge env` to download the container environment automatically and log into it | |||
```sh | |||
$ ./ge.sh env | |||
``` | |||
3. Configure the external dependency server info | |||
```sh | |||
ge config -i=121.36.**.** -u=asc** -p=Asc***\#@\!\$ (You need to add the escape character \ before the special characters $, #, !) | |||
``` | |||
4. Download and install the external libraries that the build depends on | |||
```sh | |||
$ ge update | |||
``` | |||
(Note: once inside the container, the `ge` command is already registered in the system, so the full script path is not needed) | |||
5. Run the tests; unit test cases are executed by default, and `ge test` triggers the build automatically | |||
```sh | |||
$ ge test | |||
``` | |||
## Detailed Usage | |||
Under the scripts directory, run ./ge.sh -h to list all available subcommands. | |||
```sh | |||
$ ./ge.sh -h | |||
Usage: ge COMMANDS | |||
Run ge commands | |||
Commands: | |||
env Prepare docker env | |||
config Config dependencies server | |||
update Update dependencies | |||
format Format code | |||
lint Static verify | |||
build Build code | |||
test Run test of UT/ST | |||
cov Run Coverage | |||
docs Generate documents | |||
clean Clean | |||
``` | |||
Each built-in subcommand stands for an independent feature, and each subcommand also provides second-level options so the execution can be specified flexibly. | |||
The supported options of every subcommand can be listed with `-h`. | |||
For example, to query the options of the `env` subcommand, use the following command: | |||
```sh | |||
$ ./ge.sh env -h | |||
``` | |||
When run without any option, each subcommand has a default behavior. | |||
### `ge env` | |||
This command prepares the container environment used for building and testing. The supported options are: | |||
``` | |||
$ ./ge.sh env -h | |||
Usage: ge env [OPTIONS] | |||
Prepare for docker env for build and test | |||
Options: | |||
-b, --build Build docker image | |||
-p, --pull Pull docker image | |||
-e, --enter Enter container | |||
-r, --reset Reset container | |||
-h, --help | |||
``` | |||
Option details: | |||
- `-b, --build`: build the container image to run, based on "scripts/env/Dockerfile"; | |||
- `-p, --pull` : pull the required container image from the locally configured central registry; | |||
- `-e, --enter`: log in to the container runtime environment, assuming the image already exists locally; | |||
- `-r, --reset`: remove the locally running container environment; | |||
Default: pull the corresponding image from the central registry, run an instance, and log in. | |||
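For example, when the central image registry is not reachable, a minimal flow (using only the flags listed above) is to build the image locally and then enter the container: | |||
```sh | |||
$ ./ge.sh env -b   # build the image from scripts/env/Dockerfile | |||
$ ./ge.sh env -e   # enter the container based on that image | |||
``` | |||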
### `ge config` | |||
Configure the external dependency server. The supported options are: | |||
```sh | |||
$ ge config -h | |||
Usage: ge config [OPTIONS] | |||
update server config for ge, you need to input all config info (ip, user, password) | |||
Options: | |||
-i, --ip Config ip config | |||
-u, --user Config user name | |||
-p, --password Config password | |||
-h, --help | |||
Example: ge config -i=121.36.**.** -u=asc** -p=Asc***\#@\!\$ (You need to add the escape character \ before the special characters $, #, !) | |||
``` | |||
Option details: | |||
- `-i, --ip` : set the dependency server IP address; | |||
- `-u, --user` : set the dependency server user name; | |||
- `-p, --password` : set the dependency server password; | |||
Default: print the help message. | |||
### `ge update` | |||
Install the external dependencies required to build graph engine. The supported options are: | |||
```sh | |||
$ ge update -h | |||
Usage: ge update [OPTIONS] | |||
update dependencies of build and test | |||
Options: | |||
-d, --download Download dependencies | |||
-i, --install Install dependencies | |||
-c, --clear Clear dependencies | |||
-h, --help | |||
``` | |||
Option details: | |||
- `-d, --download` : download the external dependencies needed by the build; | |||
- `-i, --install` : install the downloaded dependency packages to the corresponding locations; | |||
- `-c, --clear` : clear the downloaded dependency packages; | |||
Default: download the external dependencies according to the configuration in "scripts/update/deps_config.sh" and install them into the corresponding directories. | |||
(Note: make sure the server address, user name and password in "scripts/config/server_config.sh" have already been configured) | |||
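For example, a first-time setup can split the default behavior into its two explicit steps: | |||
```sh | |||
$ ge update -d   # download the .run packages from the configured server | |||
$ ge update -i   # extract the packages and install the libraries under the local deps directory | |||
``` | |||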
### `ge format` | |||
Format the code with clang-format. The supported options are: | |||
```sh | |||
$ ge format -h | |||
Options: | |||
-a format of all files | |||
-c format of the files changed compared to last commit, default case | |||
-l format of the files changed in last commit | |||
-h Print usage | |||
``` | |||
Option details: | |||
- `-a` : format all source files; | |||
- `-c` : format only the files changed in the current working copy; | |||
- `-l` : format the files changed in the last commit; | |||
Default: format the currently changed files. | |||
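For example, to reformat only the files touched by the last commit: | |||
```sh | |||
$ ge format -l | |||
``` | |||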
### `ge lint` | |||
Check the code format with clang-format. The supported options are: | |||
```sh | |||
$ ge lint -h | |||
Options: | |||
-a Check code format of all files, default case | |||
-c Check code format of the files changed compared to last commit | |||
-l Check code format of the files changed in last commit | |||
-h Print usage | |||
``` | |||
Option details: | |||
- `-a` : check the format of all source files; | |||
- `-c` : check only the format of the files changed in the current working copy; | |||
- `-l` : check the format of the files changed in the last commit; | |||
Default: check the format of the currently changed files. | |||
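For example, to verify that the files changed in the last commit are clang-format clean: | |||
```sh | |||
$ ge lint -l | |||
``` | |||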
### `ge build` | |||
Run the build (note: this currently delegates to the existing build.sh script and is being reworked...); | |||
### `ge test` | |||
Build and run the test cases. The supported options are: | |||
```sh | |||
$ ge test -h | |||
Usage: ge test [OPTIONS] | |||
Options: | |||
-u, --unit Run unit Test | |||
-c, --component Run component Test | |||
-h, --help | |||
``` | |||
Option details: | |||
- `-u, --unit` : run the unit tests | |||
- `-c, --component` : run the component tests | |||
Default: run the unit tests. | |||
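For example, to run the component (ST) tests instead of the default unit tests: | |||
```sh | |||
$ ge test -c | |||
``` | |||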
### `ge cov` | |||
Run the code coverage check. Both full and incremental coverage are supported; the test cases must have been run before this command. The supported options are: | |||
```sh | |||
$ ge cov -h | |||
Usage: ge cov [OPTIONS] | |||
Options: | |||
-a, --all Full coverage | |||
-i, --increment Increment coverage | |||
-d, --directory Coverage of directory | |||
-h, --help | |||
``` | |||
Option details: | |||
- `-a, --all` : collect full coverage statistics; | |||
- `-i, --increment` : run the incremental coverage check; by default it analyzes the coverage of the uncommitted changes (newly added files not yet tracked by git must be added with git add first); | |||
- `-d, --directory` : the code path on which the incremental coverage check runs; a path argument can be passed in; | |||
Default: run the incremental coverage check. | |||
The following command shows how to check the incremental coverage of all code under the ge directory: | |||
```sh | |||
$ ge cov -d=ge | |||
``` | |||
### `ge docs` | |||
Generate Doxygen documents, including the logical and physical structure of the code and their relationships, which makes the code easier to read and understand. The supported options are: | |||
```sh | |||
$ ge docs -h | |||
Usage: ge docs [OPTIONS] | |||
Options: | |||
-b, --brief Build brief docs | |||
-a, --all Build all docs | |||
-h, --help | |||
``` | |||
Option details: | |||
- `-b, --brief` : generate brief docs, skipping some relationship diagrams; fast; | |||
- `-a, --all` : generate full docs, including all kinds of code relationship diagrams; relatively slow; | |||
Default: generate the full code documents. | |||
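For example, to get a quick document build that skips the relationship diagrams: | |||
```sh | |||
$ ge docs -b | |||
``` | |||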
### `ge clean` | |||
Clean the downloaded or generated intermediate files. The supported options are: | |||
```sh | |||
$ ge clean -h | |||
Usage: ge clean [OPTIONS] | |||
Options: | |||
-b, --build Clean build dir | |||
-d, --docs Clean generate docs | |||
    -i, --install Clean dependencies | |||
-a, --all Clean all | |||
-h, --help | |||
``` | |||
Option details: | |||
- `-b, --build` : clean the temporary files produced by the build; | |||
- `-d, --docs` : clean the generated document files; | |||
- `-i, --install` : clean the installed dependency files; | |||
- `-a, --all` : clean all downloaded and generated temporary files; | |||
Default: clean the temporary files produced by the build. | |||
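For example, to remove the build output, generated docs and installed dependencies before a fully clean rebuild: | |||
```sh | |||
$ ge clean -a | |||
``` | |||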
## Follow us | |||
The toolchain is still being improved; if you run into any problem, please file an issue. Thanks! |
@@ -0,0 +1,80 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge test [OPTIONS] | |||
Options: | |||
-u, --unit Run unit Test | |||
-c, --component Run component Test | |||
-h, --help | |||
EOF | |||
} | |||
function unit_test(){ | |||
${PROJECT_HOME}/build.sh -u | |||
} | |||
function component_test(){ | |||
${PROJECT_HOME}/build.sh -s | |||
} | |||
function parse_args(){ | |||
parsed_args=$(getopt -a -o uch --long unit,component,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
unit_test | |||
exit 1 | |||
fi | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-u | --unit) | |||
unit_test | |||
;; | |||
-c | --component) | |||
component_test | |||
;; | |||
-h | --help) | |||
help | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1 | |||
;; | |||
esac | |||
shift | |||
done | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -0,0 +1,47 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
SERVER_CONFIG_FILE=${PROJECT_HOME}/scripts/config/server_config.sh | |||
[ -e "$SERVER_CONFIG_FILE" ] || { | |||
  echo "You have not configured the dependencies account info yet !!!" | |||
${PROJECT_HOME}/scripts/config/ge_config.sh -h | |||
exit 1; | |||
} | |||
source "${SERVER_CONFIG_FILE}" | |||
CPU_ARCH=ubuntu18.04.x86_64 | |||
DRIVER_VERSION=20.2.0 | |||
CHIP_NAME=A800-9010 | |||
PRODUCT_VERSION=driver_C76_TR5 | |||
DRIVER_NAME=npu-driver | |||
DRIVER_RUN_NAME=${CHIP_NAME}-${DRIVER_NAME}_${DRIVER_VERSION}_ubuntu18.04-x86_64.run | |||
DEV_TOOLS_VERSION=1.78.t10.0.b100 | |||
export ATC_RUN_NAME=Ascend-atc-${DEV_TOOLS_VERSION}-${CPU_ARCH}.run | |||
export ACL_RUN_NAME=Ascend-acllib-${DEV_TOOLS_VERSION}-${CPU_ARCH}.run | |||
export FWKACL_RUN_NAME=Ascend-fwkacllib-${DEV_TOOLS_VERSION}-${CPU_ARCH}.run | |||
DEV_TOOLS_PACKAGE=x86_ubuntu_os_devtoolset_package | |||
export DRIVER_URL=${SERVER_PATH}/${PRODUCT_VERSION}/${DRIVER_RUN_NAME} | |||
export DEV_TOOLS_URL=${SERVER_PATH}/20210428/${DEV_TOOLS_PACKAGE}.zip | |||
set +e |
@@ -0,0 +1,136 @@ | |||
#!/bin/bash | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
set -e | |||
PROJECT_HOME=${PROJECT_HOME:-$(dirname "$0")/../../} | |||
PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
DOWNLOAD_PATH=${PROJECT_HOME}/deps | |||
DEP_LIB_DIR=./lib | |||
DEP_TMP_DIR=./tmp | |||
function extract_deps_so() | |||
{ | |||
echo "begin to extract .run file ........." | |||
chmod 777 ./${DRIVER_RUN_NAME} | |||
unzip ${DEV_TOOLS_PACKAGE}.zip | |||
chmod -R 777 ${DEV_TOOLS_PACKAGE} | |||
[ -n "${DEP_TMP_DIR}" ] && rm -rf "${DEP_TMP_DIR}" | |||
./${DRIVER_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/driver | |||
./${DEV_TOOLS_PACKAGE}/${ATC_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/atc | |||
./${DEV_TOOLS_PACKAGE}/${ACL_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/acllib | |||
./${DEV_TOOLS_PACKAGE}/${FWKACL_RUN_NAME} --noexec --extract=${DEP_TMP_DIR}/fwkacllib | |||
} | |||
function copy_so_to_target_dir() | |||
{ | |||
mkdir -p $DEP_LIB_DIR | |||
mv ${DEP_TMP_DIR}/driver/driver $DEP_LIB_DIR/driver | |||
mv ${DEP_TMP_DIR}/atc/atc $DEP_LIB_DIR/atc | |||
mv ${DEP_TMP_DIR}/acllib/acllib $DEP_LIB_DIR/acllib | |||
mv ${DEP_TMP_DIR}/fwkacllib/fwkacllib $DEP_LIB_DIR/fwkacllib | |||
} | |||
function clear_libs() | |||
{ | |||
[ -n "${DOWNLOAD_PATH}" ] && rm -rf "${DOWNLOAD_PATH}" | |||
} | |||
function download_runs() | |||
{ | |||
  source "${PROJECT_HOME}/scripts/update/deps_config.sh" | |||
echo "begin to download .run file ........." | |||
clear_libs | |||
  mkdir -p "${DOWNLOAD_PATH}" | |||
pushd "${DOWNLOAD_PATH}" >/dev/null | |||
cd ${DOWNLOAD_PATH} | |||
wget --user=${DEP_USER} --password=${DEP_PASSWORD} ${DRIVER_URL} | |||
wget --user=${DEP_USER} --password=${DEP_PASSWORD} ${DEV_TOOLS_URL} | |||
popd >/dev/null | |||
} | |||
function install_deps() | |||
{ | |||
  source "${PROJECT_HOME}/scripts/update/deps_config.sh" | |||
  mkdir -p "${DOWNLOAD_PATH}" | |||
pushd "${DOWNLOAD_PATH}" >/dev/null | |||
cd ${DOWNLOAD_PATH} | |||
extract_deps_so | |||
copy_so_to_target_dir | |||
popd >/dev/null | |||
} | |||
function help(){ | |||
cat <<-EOF | |||
Usage: ge update [OPTIONS] | |||
update dependencies of build and test | |||
Options: | |||
-d, --download Download dependencies | |||
-i, --install Install dependencies | |||
-c, --clear Clear dependencies | |||
-h, --help | |||
EOF | |||
} | |||
function parse_args(){ | |||
parsed_args=$(getopt -a -o dich --long download,install,clear,help -- "$@") || { | |||
help | |||
exit 1 | |||
} | |||
if [ $# -lt 1 ]; then | |||
download_runs | |||
install_deps | |||
exit 1 | |||
fi | |||
eval set -- "$parsed_args" | |||
while true; do | |||
case "$1" in | |||
-d | --download) | |||
download_runs | |||
;; | |||
-i | --install) | |||
install_deps | |||
;; | |||
-c | --clear) | |||
clear_libs | |||
;; | |||
-h | --help) | |||
help; exit 1; | |||
;; | |||
--) | |||
shift; break; | |||
;; | |||
*) | |||
help; exit 1 | |||
;; | |||
esac | |||
shift | |||
done | |||
} | |||
function main(){ | |||
parse_args "$@" | |||
} | |||
main "$@" | |||
set +e |
@@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES | |||
set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | |||
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | |||
"${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" | |||
) | |||
@@ -715,7 +716,10 @@ set(PASS_TEST_FILES | |||
"graph/passes/mark_node_unknown_shape_pass_unittest.cc" | |||
"graph/passes/reshape_recovery_pass_unittest.cc" | |||
"graph/passes/cast_remove_pass_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||
) | |||
set(KERNEL_TEST_FILES | |||
@@ -798,6 +802,8 @@ set(MULTI_PARTS_TEST_FILES | |||
"graph/manager/run_graph_unittest.cc" | |||
"graph/partition/dynamic_shape_partition_unittest.cc" | |||
"graph/manager/graph_manager_unittest.cc" | |||
"graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||
"graph/optimize/graph_optimize_unittest.cc" | |||
"session/omg_omg_unittest.cc" | |||
"session/ge_api_unittest.cc" | |||
"session/inner_session_unittest.cc" | |||
@@ -832,6 +838,7 @@ set(HYBRID_TEST_FILES | |||
"hybrid/executor/worker/execution_engine_unittest.cc" | |||
"hybrid/model/hybrid_model_builder_unittest.cc" | |||
"hybrid/node_executor/rts/rts_node_task_unittest.cc" | |||
"hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | |||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | |||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | |||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | |||
@@ -0,0 +1,239 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <memory> | |||
#include <iostream> | |||
#define protected public | |||
#define private public | |||
#include "graph/optimize/graph_optimize.h" | |||
#include "init/gelib.h" | |||
#include "ge/ge_api.h" | |||
#undef private | |||
#undef protected | |||
using namespace std; | |||
using namespace testing; | |||
using namespace ge; | |||
namespace { | |||
const char *const kVectorCore = "VectorCore"; | |||
const char *const kAicoreEngine = "AIcoreEngine"; | |||
string CreateEngineConfigJson() { | |||
GELOGI("Begin to create engine config json file."); | |||
string base_path = PluginManager::GetPath(); | |||
GELOGI("Base path is %s.", base_path.c_str()); | |||
string dir_path = base_path.substr(0, base_path.rfind('/') + 1) + "plugin/nnengine/ge_config"; | |||
string cmd = "mkdir -p " + dir_path; | |||
system(cmd.c_str()); | |||
string file_path = dir_path + "/engine_conf.json"; | |||
GELOGI("Begin to write into the config file: %s.", file_path.c_str()); | |||
ofstream ofs(file_path, ios::out); | |||
EXPECT_EQ(!ofs, false); | |||
ofs << "{\n" | |||
" \"schedule_units\" : [ {\n" | |||
" \"id\" : \"TS_1\",\n" | |||
" \"name\" : \"1980_hwts\",\n" | |||
" \"ex_attrs\" : \"\",\n" | |||
" \"cal_engines\" : [\n" | |||
" {\"id\" : \"DNN_VM_GE_LOCAL\", \"name\" : \"GE_LOCAL\", \"independent\" : false, \"attch\" : true, \"skip_assign_stream\" : true },\n" | |||
" {\"id\" : \"AIcoreEngine\", \"name\" : \"AICORE\", \"independent\" : false, \"attch\" : false, \"skip_assign_stream\" : false}\n" | |||
" ]\n" | |||
" } ]\n" | |||
"}"; | |||
ofs.close(); | |||
GELOGI("Json config file %s has been written.", file_path.c_str()); | |||
return file_path; | |||
} | |||
void DeleteFile(const string &file_name) { | |||
auto ret = remove(file_name.c_str()); | |||
if (ret == 0) { | |||
GELOGI("Delete file successfully, file:%s.", file_name.c_str()); | |||
} | |||
} | |||
} | |||
class UtestGraphOptimizeTest : public testing::Test { | |||
protected: | |||
void SetUp() { | |||
config_file_ = CreateEngineConfigJson(); | |||
} | |||
void TearDown() { | |||
DeleteFile(config_file_); | |||
} | |||
private: | |||
string config_file_; | |||
}; | |||
class TestGraphOptimizerSuccess : public GraphOptimizer { | |||
public: | |||
~TestGraphOptimizerSuccess() override { Finalize(); } | |||
Status Initialize(const map<string, string> &options) override { return SUCCESS; } | |||
Status Finalize() override { return SUCCESS; } | |||
Status OptimizeGraphPrepare(ComputeGraph& graph) override { return SUCCESS; } | |||
Status OptimizeGraphBeforeBuild(ComputeGraph& graph) override { return SUCCESS; } | |||
Status OptimizeOriginalGraph(ComputeGraph &graph) override { return SUCCESS; } | |||
Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) override { return SUCCESS; } | |||
Status OptimizeFusedGraph(ComputeGraph &graph) override { return SUCCESS; } | |||
Status OptimizeWholeGraph(ComputeGraph &graph) override { return SUCCESS; } | |||
Status GetAttributes(GraphOptimizerAttribute &attrs) const override { | |||
attrs.engineName = "AIcoreEngine"; | |||
attrs.scope = OPTIMIZER_SCOPE::ENGINE; | |||
return SUCCESS; | |||
} | |||
Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) override { return SUCCESS; } | |||
Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) override { return SUCCESS; } | |||
Status OptimizeAfterStage1(ComputeGraph &graph) override { return SUCCESS; } | |||
}; | |||
class TestGraphOptimizerFail : public GraphOptimizer { | |||
public: | |||
~TestGraphOptimizerFail() override { Finalize(); } | |||
Status Initialize(const map<string, string> &options) override { return SUCCESS; } | |||
Status Finalize() override { return SUCCESS; } | |||
Status OptimizeGraphPrepare(ComputeGraph& graph) override { return FAILED; } | |||
Status OptimizeGraphBeforeBuild(ComputeGraph& graph) override { return FAILED; } | |||
Status OptimizeOriginalGraph(ComputeGraph &graph) override { return FAILED; } | |||
Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) override { return FAILED; } | |||
Status OptimizeFusedGraph(ComputeGraph &graph) override { return FAILED; } | |||
Status OptimizeWholeGraph(ComputeGraph &graph) override { return FAILED; } | |||
Status GetAttributes(GraphOptimizerAttribute &attrs) const override { | |||
attrs.engineName = "AIcoreEngine"; | |||
attrs.scope = OPTIMIZER_SCOPE::ENGINE; | |||
return SUCCESS; | |||
} | |||
Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) override { return FAILED; } | |||
Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) override { return FAILED; } | |||
Status OptimizeAfterStage1(ComputeGraph &graph) override { return FAILED; } | |||
}; | |||
TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_succ) { | |||
map<string, string> options; | |||
Status ret = ge::GELib::Initialize(options); | |||
EXPECT_EQ(ret, SUCCESS); | |||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
EXPECT_NE(instance_ptr, nullptr); | |||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerSuccess>(); | |||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
GraphOptimize base_optimize; | |||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
base_optimize.core_type_ = kVectorCore; | |||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = instance_ptr->Finalize(); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_fail) { | |||
ComputeGraphPtr compute_graph = nullptr; | |||
GraphOptimize base_optimize; | |||
// 1. Input graph is nullptr. | |||
Status ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||
EXPECT_EQ(ret, PARAM_INVALID); | |||
// 2. GELib is not initialized. | |||
compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||
EXPECT_EQ(ret, GE_CLI_GE_NOT_INITIALIZED); | |||
// 3. The optimizer registered with the engine returned a failure. | |||
map<string, string> options; | |||
ret = ge::GELib::Initialize(options); | |||
EXPECT_EQ(ret, SUCCESS); | |||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
EXPECT_NE(instance_ptr, nullptr); | |||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerFail>(); | |||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||
ret = base_optimize.OptimizeAfterStage1(compute_graph); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = instance_ptr->Finalize(); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(UtestGraphOptimizeTest, test_optimizers_succ) { | |||
map<string, string> options; | |||
Status ret = ge::GELib::Initialize(options); | |||
EXPECT_EQ(ret, SUCCESS); | |||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
EXPECT_NE(instance_ptr, nullptr); | |||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerSuccess>(); | |||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
GraphOptimize base_optimize; | |||
ret = base_optimize.OptimizeOriginalGraph(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = base_optimize.OptimizeOriginalGraphJudgeInsert(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = base_optimize.OptimizeOriginalGraphForQuantize(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = base_optimize.OptimizeGraphBeforeBuildForRts(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = base_optimize.OptimizeWholeGraph(compute_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = instance_ptr->Finalize(); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(UtestGraphOptimizeTest, test_optimizers_fail) { | |||
map<string, string> options; | |||
Status ret = ge::GELib::Initialize(options); | |||
EXPECT_EQ(ret, SUCCESS); | |||
shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
EXPECT_NE(instance_ptr, nullptr); | |||
GraphOptimizerPtr graph_opt = MakeShared<TestGraphOptimizerFail>(); | |||
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
GraphOptimize base_optimize; | |||
ret = base_optimize.OptimizeOriginalGraph(compute_graph); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = base_optimize.OptimizeOriginalGraphJudgeInsert(compute_graph); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = base_optimize.OptimizeOriginalGraphForQuantize(compute_graph); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = base_optimize.OptimizeGraphBeforeBuildForRts(compute_graph); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = base_optimize.OptimizeWholeGraph(compute_graph); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = instance_ptr->Finalize(); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} |
@@ -0,0 +1,150 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <cstdint> | |||
#include <string> | |||
#include <gtest/gtest.h> | |||
#define protected public | |||
#define private public | |||
#include "graph/optimize/graph_optimize.h" | |||
#undef protected | |||
#undef private | |||
#include "../passes/graph_builder_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
namespace ge { | |||
class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
namespace { | |||
/* | |||
* Data -cast - netoutput | |||
*/ | |||
ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string &subgraph_name) { | |||
  auto sub_builder = ut::GraphBuilder(subgraph_name); | |||
auto data1 = sub_builder.AddNode("data1", DATA, 0,1); | |||
auto cast = sub_builder.AddNode("cast", CAST, 1,1); | |||
auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); | |||
AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); | |||
sub_builder.AddDataEdge(data1,0,cast,0); | |||
sub_builder.AddDataEdge(cast,0,netoutput,0); | |||
return sub_builder.GetGraph(); | |||
} | |||
/* | |||
* const - allreduce | |||
* \ if | |||
* insert identity | |||
*/ | |||
ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { | |||
auto builder = ut::GraphBuilder("test"); | |||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||
auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto if_node = builder.AddNode("if", IF, 1,0); | |||
builder.AddDataEdge(const1, 0, allreduce, 0); | |||
builder.AddDataEdge(const1, 0, if_node, 0); | |||
builder.AddControlEdge(ctrl_const, allreduce); | |||
auto root_graph = builder.GetGraph(); | |||
string subgraph_name = "then_branch"; | |||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||
then_branch_graph->SetParentNode(if_node); | |||
then_branch_graph->SetParentGraph(root_graph); | |||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||
return root_graph; | |||
} | |||
/* const1---allreduce const1--identity - allreduce | |||
* / / | |||
* var-identity--cast1 ==> var-----cast1 | |||
* \ \ | |||
* if if | |||
*/ | |||
ComputeGraphPtr BuildGraph_Identity_Split() { | |||
auto builder = ut::GraphBuilder("g1"); | |||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
auto identity = builder.AddNode("identity", IDENTITY, 1, 1); | |||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||
auto if_node = builder.AddNode("if", IF, 1,0); | |||
builder.AddDataEdge(var, 0 , identity, 0); | |||
builder.AddDataEdge(identity, 0 , allreduce, 0); | |||
builder.AddDataEdge(identity, 0 , cast1, 0); | |||
builder.AddDataEdge(identity, 0 , if_node, 0); | |||
builder.AddControlEdge(const1, allreduce); | |||
auto root_graph = builder.GetGraph(); | |||
string subgraph_name = "then_branch"; | |||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||
then_branch_graph->SetParentNode(if_node); | |||
then_branch_graph->SetParentGraph(root_graph); | |||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||
return root_graph; | |||
} | |||
/* | |||
* mul == allreduce | |||
* need insert identity | |||
*/ | |||
ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { | |||
auto builder = ut::GraphBuilder("test"); | |||
auto mul = builder.AddNode("mul", MUL, 2,1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); | |||
AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); | |||
builder.AddDataEdge(mul,0,allreduce,0); | |||
builder.AddDataEdge(mul,0,allreduce,1); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
// const -> allreduce | |||
// const -> Identity -> allreduce | |||
TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) { | |||
ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite(); | |||
GraphOptimize graph_optimizer; | |||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
auto allreduce = graph->FindNode("allreduce"); | |||
EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY); | |||
} | |||
TEST(UtestGraphPassesHcclMemcpyPass, testIdentitySplit) { | |||
  ComputeGraphPtr graph = BuildGraph_Identity_Split(); | |||
GraphOptimize graph_optimizer; | |||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
auto allreduce = graph->FindNode("allreduce"); | |||
auto allreduce_in_node = allreduce->GetInDataNodes().at(0); | |||
EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY); | |||
EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT); | |||
} | |||
TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) { | |||
ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite(); | |||
EXPECT_EQ(graph->GetDirectNodesSize(), 2); | |||
GraphOptimize graph_optimizer; | |||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
EXPECT_EQ(graph->GetDirectNodesSize(), 3); | |||
} | |||
} // namespace ge |
@@ -0,0 +1,79 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <cstdint> | |||
#include <string> | |||
#include <gtest/gtest.h> | |||
#include "common/ge_inner_error_codes.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||
#undef protected | |||
#undef private | |||
#include "graph_builder_utils.h" | |||
namespace ge { | |||
class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
namespace { | |||
/* | |||
* var var | |||
* | \ | \ | |||
* | assign | assign | |||
* | // =======> | // | |||
* allreduce identity | |||
* | | | |||
* netoutput allreduce | |||
* | | |||
* netoutput | |||
*/ | |||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||
auto builder = ut::GraphBuilder("test"); | |||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
builder.AddDataEdge(var, 0, assign, 0); | |||
builder.AddDataEdge(var,0,allreduce,0); | |||
builder.AddControlEdge(assign, allreduce); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
// const -> allreduce | |||
// const -> Identity -> allreduce | |||
TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) { | |||
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||
auto src_node = graph->FindNode("var"); | |||
auto dst_node = graph->FindNode("allreduce"); | |||
// test InsertIdentityBeforeHccl | |||
HcclContinuousMemcpyPass hccl_continuous_memcpy_pass; | |||
hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0)); | |||
// check | |||
dst_node = graph->FindNode("allreduce"); | |||
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||
} | |||
} // namespace ge |
@@ -0,0 +1,80 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <cstdint> | |||
#include <string> | |||
#include <gtest/gtest.h> | |||
#include "common/ge_inner_error_codes.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/hccl_memcpy_pass.h" | |||
#undef protected | |||
#undef private | |||
#include "graph_builder_utils.h" | |||
namespace ge { | |||
class UtestGraphPassesHcclMemcpyPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
namespace { | |||
/* | |||
* var var | |||
* | \ | \ | |||
* | assign | assign | |||
* | // =======> | // | |||
* allreduce identity | |||
* | | | |||
* netoutput allreduce | |||
* | | |||
* netoutput | |||
*/ | |||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||
auto builder = ut::GraphBuilder("test"); | |||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
builder.AddDataEdge(var, 0, assign, 0); | |||
builder.AddDataEdge(var,0,allreduce,0); | |||
builder.AddControlEdge(assign, allreduce); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
// var -> allreduce | |||
// var -> Identity -> allreduce | |||
TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) { | |||
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||
auto src_node = graph->FindNode("var"); | |||
auto dst_node = graph->FindNode("allreduce"); | |||
// test InsertIdentityBeforeHccl | |||
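// Expected result matches the diagram above: an Identity node ends up between var and | |||
// allreduce, and it carries the control edge coming from assign. | |||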
HcclMemcpyPass hccl_memcpy_pass; | |||
hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), | |||
dst_node->GetInDataAnchor(0)); | |||
// check | |||
dst_node = graph->FindNode("allreduce"); | |||
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||
} | |||
} // namespace ge |
@@ -86,7 +86,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
* | | |||
* Merge | |||
* / \. | |||
* / \. | |||
* Active / \ Active | |||
* / \. | |||
* Add Sub | |||
* | \ / | | |||
@@ -96,8 +96,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
* Switch Switch | |||
* | \ / | | |||
* | \ / | | |||
* | \ / | | |||
* | \ / | | |||
* | Active | | |||
* | \ / | | |||
* | Less | | |||
* | / \ | | |||
* | / \ | | |||
@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||
} | |||
const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); | |||
const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. | |||
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | |||
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||
@@ -135,13 +135,14 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | |||
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | |||
const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); | |||
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); | |||
const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. | |||
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. | |||
const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | |||
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | |||
const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | |||
const auto iteration1 = CreateNode(graph, "iteration1", NEXTITERATION, 1, 1); | |||
const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | |||
output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||
@@ -170,7 +171,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||
GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), iteration1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(iteration1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
} | |||
TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | |||
@@ -28,6 +28,7 @@ | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/ge_local_context.h" | |||
#include "graph/common/omg_util.h" | |||
using namespace std; | |||
using namespace testing; | |||
@@ -157,7 +158,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
AttrUtils::SetStr(merge1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next1->GetName()); | |||
SetNextIteration(merge1, next1); | |||
AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | |||
AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | |||
@@ -0,0 +1,114 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <gmock/gmock.h> | |||
#include <vector> | |||
#define private public | |||
#define protected public | |||
#include "hybrid/executor/subgraph_context.h" | |||
#include "hybrid/node_executor/ge_local/ge_local_node_executor.h" | |||
#include "model/ge_root_model.h" | |||
#undef protected | |||
#undef private | |||
using namespace std; | |||
using namespace testing; | |||
namespace ge { | |||
using namespace hybrid; | |||
class UtestGeLocalNodeExecutor : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() { } | |||
}; | |||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||
op_desc->SetStreamId(0); | |||
static int32_t index = 0; | |||
op_desc->SetId(index++); | |||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||
TensorUtils::SetSize(tensor, 64); | |||
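// All inputs and outputs share the same 64-byte INT64 tensor; offsets are assigned | |||
// contiguously, inputs first and outputs after them. | |||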
vector<int64_t> input_offset; | |||
for (int i = 0; i < in_num; i++) { | |||
op_desc->AddInputDesc(tensor); | |||
input_offset.emplace_back(i * 64); | |||
} | |||
op_desc->SetInputOffset(input_offset); | |||
vector<int64_t> output_offset; | |||
for (int i = 0; i < out_num; i++) { | |||
op_desc->AddOutputDesc(tensor); | |||
output_offset.emplace_back(in_num * 64 + i * 64); | |||
} | |||
op_desc->SetOutputOffset(output_offset); | |||
op_desc->SetWorkspace({}); | |||
op_desc->SetWorkspaceBytes({}); | |||
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||
return graph.AddNode(op_desc); | |||
} | |||
TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||
ge_root_model->SetModelName("test_name"); | |||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||
HybridModel hybrid_model(ge_root_model); | |||
NodePtr node = CreateNode(*graph, "noop", NOOP, 0, 0); | |||
std::unique_ptr<NodeItem> new_node; | |||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||
NodeItem *node_item = new_node.get(); | |||
hybrid_model.node_items_[node] = std::move(new_node); | |||
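// Ownership of the NodeItem moves into the HybridModel; node_item stays valid as a raw | |||
// pointer for wiring up the graph item below. | |||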
node_item->input_start = 0; | |||
node_item->output_start = 0; | |||
GraphItem graph_item; | |||
graph_item.node_items_.emplace_back(node_item); | |||
graph_item.total_inputs_ = 0; | |||
graph_item.total_outputs_ = 0; | |||
GraphExecutionContext graph_context; | |||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
ASSERT_NE(unique_task_context, nullptr); | |||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
node_state->SetTaskContext(shared_task_context); | |||
NodeTaskPtr task = nullptr; | |||
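// GeLocalNodeExecutor is expected to produce a task for the NoOp node; since the node has | |||
// no inputs or outputs, UpdateArgs and ExecuteAsync should succeed without doing any work. | |||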
GeLocalNodeExecutor node_executor; | |||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
ASSERT_NE(task, nullptr); | |||
ASSERT_EQ(task->UpdateArgs(*node_state->GetTaskContext()), SUCCESS); | |||
std::function<void()> done = []() {}; | |||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||
} | |||
} // namespace ge |