diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index f83d2607..398dafcb 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -166,6 +166,11 @@ set(EXECUTOR_SRC_LIST "graph/manager/util/debug.cc" #"graph/manager/util/hcom_util.cc" # Just for runner. "graph/passes/pass_utils.cc" + "graph/passes/mds_pass.cc" + "graph/passes/mds_kernels/mds_utils.cc" + "graph/passes/mds_kernels/variable_mds_kernel.cc" + "graph/passes/mds_kernels/conv2d_mds_kernel.cc" + "graph/passes/mds_kernels/base_mds_kernel.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" "host_kernels/broadcast_gradient_args_kernel.cc" diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc index 7f2de735..d93463af 100755 --- a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc +++ b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc @@ -77,9 +77,11 @@ Status GeLocalOpsKernelInfoStore::DestroySession(const map &sess return SUCCESS; } Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) { - //TODO:1,针对变量类型是否标识为可训练变量 - //2,是否开启smdp1和3 - //满足上述两点,在变量切分信息里面设置为当前变量节点可切 + // TODO: + // 1. Whether the variable type is marked as a trainable variable + // 2. Whether smdp1 and smdp3 are enabled + // When both conditions hold, mark the current variable + // node as cuttable in the variable cut info return SUCCESS; } } // namespace ge_local diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc index 49cda128..6a53c51c 100755 --- a/ge/graph/execute/graph_execute.cc +++ b/ge/graph/execute/graph_execute.cc @@ -540,14 +540,12 @@ Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_ro return SUCCESS; } Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &inputs, - const RunAsyncCallback &callback) { + const RunAsyncCallback &callback) { // get deploy number of model instance auto root_graph = ge_root_model->GetRootGraph(); vector deploy_info; if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { - GELOGE(FAILED, - "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", - root_graph->GetName().c_str(), + GELOGE(FAILED, "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str()); return FAILED; } @@ -567,7 +565,7 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model std::vector graph_inputs; if (ge::AttrUtils::MutableListTensor(thread_instance, kAttrGraphInputs, graph_inputs)) { std::vector graph_input_updated(inputs.begin(), inputs.end()); - for (const auto &ge_tensor_ptr:graph_inputs) { + for (const auto &ge_tensor_ptr : graph_inputs) { graph_input_updated.push_back(TensorAdapter::AsTensor(*ge_tensor_ptr)); } GraphExecutor graph_executor; @@ -575,12 +573,12 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model std::future f; bool need_return_result = false; if ((ge::AttrUtils::GetBool(thread_instance, kAttrNeedReturnResult, need_return_result) && need_return_result)) { - f = executor.commit(execute_model_func, &graph_executor, ge_root_model, - model_ids[i], graph_input_updated, callback); + f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, + callback); } else { RunAsyncCallback callback_stub; - f =
executor.commit(execute_model_func, &graph_executor, ge_root_model, - model_ids[i], graph_input_updated, callback_stub); + f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, + callback_stub); } if (!f.valid()) { GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); return FAILED; } @@ -600,8 +598,8 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model return SUCCESS; } -Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, const std::vector &inputs, - const RunAsyncCallback &callback) { +Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, + const std::vector &inputs, const RunAsyncCallback &callback) { if (model_id == kInvalidModelId) { GELOGE(INTERNAL_ERROR, "No valid model id."); return INTERNAL_ERROR; diff --git a/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc index 6a007dd2..df60df4c 100755 --- a/ge/graph/load/graph_loader.cc +++ b/ge/graph/load/graph_loader.cc @@ -149,17 +149,14 @@ Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr auto root_graph = ge_root_model->GetRootGraph(); vector deploy_info; if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { - GELOGE(FAILED, - "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s", - root_graph->GetName().c_str(), - ATTR_NAME_DEPLOY_INFO.c_str()); + GELOGE(FAILED, "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s", + root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str()); return FAILED; } auto thread_instances_size = deploy_info.size(); auto device_id_fission_from = GetContext().DeviceId(); - GELOGI("Graph %s need to load model %zu times, and fission from device %u.", - root_graph->GetName().c_str(), thread_instances_size, - device_id_fission_from); + GELOGI("Graph %s need to load model %zu times, and fission from device %u.", root_graph->GetName().c_str(), + thread_instances_size, device_id_fission_from); ThreadPool executor(thread_instances_size); std::vector> vector_future; GE_TIMESTAMP_START(LoadModelOnline); @@ -167,17 +164,16 @@ Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr auto thread_instance = deploy_info[i]; std::string device_type; ModelIdInfo model_id_info; - //TODO: listener要区分同步异步 std::shared_ptr listener; if (is_async) { listener = MakeShared(); GE_CHECK_NOTNULL(listener); } else { - + // TODO: create a GraphModelListener for the synchronous case } int64_t device_id_fissioned = kInvalidDieId; - if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) - || device_id_fissioned == kInvalidDieId) { + if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) || + device_id_fissioned == kInvalidDieId) { REPORT_CALL_ERROR("E19999", "graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str()); GELOGE(GRAPH_FAILED, "[LoadMultiModelOnline] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),
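Review note on the two multi-instance paths above: both LoadMultiModelOnline and AsyncMultiExecuteModel read the ATTR_NAME_DEPLOY_INFO list, commit one task per deploy instance to a ThreadPool, and then join the returned futures, failing fast on an invalid future. A minimal standalone sketch of that fan-out/join pattern, with std::async standing in for GE's ThreadPool::commit (ExecuteOneInstance and kDeployNumber here are illustrative names, not GE APIs):

#include <future>
#include <iostream>
#include <vector>

// Stand-in for one GraphExecutor::AsyncExecuteModel call on one device.
int ExecuteOneInstance(int device_id) {
  return device_id < 0 ? 1 : 0;  // pretend status for this device
}

int main() {
  const int kDeployNumber = 2;  // one task per deploy_info entry
  std::vector<std::future<int>> futures;
  futures.reserve(kDeployNumber);
  for (int device_id = 0; device_id < kDeployNumber; ++device_id) {
    // Fan out: each instance runs concurrently, like executor.commit(...)
    futures.emplace_back(std::async(std::launch::async, ExecuteOneInstance, device_id));
  }
  for (auto &f : futures) {
    if (!f.valid()) {  // mirrors the "Future is invalid" check in the patch
      std::cerr << "invalid future" << std::endl;
      return 1;
    }
    std::cout << "instance status: " << f.get() << std::endl;  // join
  }
  return 0;
}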
diff --git a/ge/graph/passes/mds_kernels/base_mds_kernel.cc b/ge/graph/passes/mds_kernels/base_mds_kernel.cc index 602b021f..53f47ac4 100644 --- a/ge/graph/passes/mds_kernels/base_mds_kernel.cc +++ b/ge/graph/passes/mds_kernels/base_mds_kernel.cc @@ -26,8 +26,7 @@ shared_ptr GetKernelByType(const NodePtr &node) { string type = node->GetType(); if (type == FRAMEWORKOP) { if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { - REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", - ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), + REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), node->GetName().c_str(), node->GetType().c_str()); return nullptr; } @@ -35,10 +34,10 @@ shared_ptr GetKernelByType(const NodePtr &node) { return factory.Create(type); } -} +} // namespace mds_cut_pass shared_ptr DeploySchedulerKernel::Instance() { static const std::shared_ptr instance_ptr = - shared_ptr(new(std::nothrow) DeploySchedulerKernel()); + shared_ptr(new (std::nothrow) DeploySchedulerKernel()); return instance_ptr; } @@ -63,20 +62,15 @@ Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) { if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { tensor_desc->SetShape(src_tensor_desc->GetShape()); } else { - MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor), - "[CutN] failed to gather between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), - in_anchor->GetIdx()); + MDS_REQUIRE_SUCCESS( + MdsUtils::DataGather(src_anchor, in_anchor), "[CutN] failed to gather between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } } else { if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), "[CutN] failed to slice between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } else { tensor_desc->SetShape(src_tensor_desc->GetShape()); @@ -85,14 +79,11 @@ Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) { // insert hcomallreduce for cutn bool is_grad_compute_node = false; - if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) - && is_grad_compute_node) { - MDS_REQUIRE_SUCCESS(MdsUtils::DataReduce(src_anchor, in_anchor), - "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), - in_anchor->GetIdx()); + if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && + is_grad_compute_node) { + MDS_REQUIRE_SUCCESS( + MdsUtils::DataReduce(src_anchor, in_anchor), "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } } @@ -122,32 +113,23 @@ Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) { if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()), "[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } else { - MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor), - "[CutH] failed to gather between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), - in_anchor->GetIdx()); + MDS_REQUIRE_SUCCESS( + MdsUtils::DataGather(src_anchor,
in_anchor), "[CutH] failed to gather between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } } else { if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), "[CutH] failed to slice between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } else { MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true), "[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", - src_op_desc->GetName().c_str(), - src_anchor->GetIdx(), - op_desc->GetName().c_str(), + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); } } @@ -157,4 +139,4 @@ Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) { MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] call infer shape failed", node->GetName().c_str()); return SUCCESS; } -} +} // namespace ge diff --git a/ge/graph/passes/mds_kernels/base_mds_kernel.h b/ge/graph/passes/mds_kernels/base_mds_kernel.h index 56670c1e..8f46a696 100644 --- a/ge/graph/passes/mds_kernels/base_mds_kernel.h +++ b/ge/graph/passes/mds_kernels/base_mds_kernel.h @@ -55,12 +55,16 @@ class DeploySchedulerKernel { // halo exchange process Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false); - NodePtr GetInputNode() { return input_node_; } + NodePtr GetInputNode() { + return input_node_; + } DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete; DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete; + protected: DeploySchedulerKernel() = default; virtual ~DeploySchedulerKernel() = default; + private: NodePtr input_node_ = nullptr; }; @@ -69,4 +73,4 @@ namespace mds_cut_pass { shared_ptr GetKernelByType(const NodePtr &node); } } // namespace ge -#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ \ No newline at end of file +#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ \ No newline at end of file diff --git a/ge/graph/passes/mds_kernels/mds_utils.cc b/ge/graph/passes/mds_kernels/mds_utils.cc index 58ab0ef6..b942b8ab 100644 --- a/ge/graph/passes/mds_kernels/mds_utils.cc +++ b/ge/graph/passes/mds_kernels/mds_utils.cc @@ -21,39 +21,48 @@ thread_local int64_t data_slice_count = 0; thread_local int64_t data_gather_count = 0; thread_local int64_t data_reduce_count = 0; const std::string kPrefix = "mds"; -} +} // namespace int64_t MdsUtils::GetNLocation(Format fmt) { int64_t loc = kNInvalidLocation; switch (fmt) { - case FORMAT_NCHW : - case FORMAT_NHWC :loc = kNLocation0; + case FORMAT_NCHW: + case FORMAT_NHWC: + loc = kNLocation0; break; case FORMAT_CHWN: - case FORMAT_HWCN:loc = kNLocation3; + case FORMAT_HWCN: + loc = kNLocation3; break; - default:GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); + default: + GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); } return loc; } int64_t MdsUtils::GetHLocation(Format fmt) { int64_t loc = kHInvalidLocation; switch (fmt) { - case FORMAT_HWCN :loc = kHLocation0; + case FORMAT_HWCN: + loc = kHLocation0; break; - case FORMAT_NHWC : - case FORMAT_CHWN :loc = kHLocation1; + case FORMAT_NHWC: + case FORMAT_CHWN: + loc 
= kHLocation1; break; - case FORMAT_NCHW :loc = kHLocation2; - default :GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); + case FORMAT_NCHW: + loc = kHLocation2; + break; + default: + GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); } return loc; } int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) { Format fmt = ge_tensor_desc->GetFormat(); switch (type) { - case kCutN :return GetNLocation(fmt); - case kCutH :return GetHLocation(fmt); - default :; + case kCutN: + return GetNLocation(fmt); + case kCutH: + return GetHLocation(fmt); + default:; } GELOGE(FAILED, "[MDS]invalid CutType:%d", type); return kInvalidIndex; @@ -64,7 +73,7 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des GELOGE(FAILED, "[MDS]invalid input param: tensor is null!"); return false; } - if (type != kCutN || type != kCutH) { + if (type != kCutN && type != kCutH) { REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type); GELOGE(FAILED, "[MDS]invalid CutType:%d", type); return false; @@ -75,6 +84,18 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des GELOGE(FAILED, "[MDS]", "invalid index param:%ld", cut_index); return false; } + auto dims = ge_tensor_desc->GetShape().GetDims(); + if (cut_index < 0 || static_cast<size_t>(cut_index) >= dims.size()) { + REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, + dims.size()); + GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, + dims.size()); + return false; + } + if (dims[cut_index] % kDeployNumber != 0) { + GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]); + return false; + } vector cut_support_info; if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) { REPORT_INNER_ERROR("E19999", "call GetlistInt failed"); return false; } if (cut_index < 0 || cut_index >= cut_support_info.size()) { - REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", - cut_index, type, cut_support_info.size()); - GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", - cut_index, type, cut_support_info.size()); + REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, + type, cut_support_info.size()); + GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, + type, cut_support_info.size()); return false; } if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) { @@ -95,8 +116,7 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des return false; } return cut_support_info[cut_index] & kSplitCutSupported; }
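Review note: the change from "type != kCutN || type != kCutH" to "&&" above fixes a classic always-true predicate, since any value differs from at least one of two distinct constants. A compilable illustration of why the old form rejected even valid cut types (the enum values here are stand-ins for the real CutType):

#include <cassert>

enum CutType { kCutN, kCutH, kNoCut };

// Buggy form: true for every input, so even kCutN was treated as invalid.
bool IsInvalidCutBuggy(CutType type) { return type != kCutN || type != kCutH; }
// Fixed form: true only for values that are neither kCutN nor kCutH.
bool IsInvalidCutFixed(CutType type) { return type != kCutN && type != kCutH; }

int main() {
  assert(IsInvalidCutBuggy(kCutN));   // the bug: a valid type reported invalid
  assert(!IsInvalidCutFixed(kCutN));  // kCutN accepted
  assert(!IsInvalidCutFixed(kCutH));  // kCutH accepted
  assert(IsInvalidCutFixed(kNoCut));  // other values still rejected
  return 0;
}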
-Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, - CutType type, int64_t deploy_number) { +Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) { GE_CHECK_NOTNULL(ge_tensor_desc); auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type); auto dims = ge_tensor_desc->GetShape().GetDims(); @@ -109,14 +129,10 @@ Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, } Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) { GE_CHECK_NOTNULL(hcom_op); - REQUIRE(fission_factor > kDefaultFissionFactor, - "fission_factor %ld need be bigger than %ld", - fission_factor, + REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld must be greater than %ld", fission_factor, kDefaultFissionFactor); REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor), - "Failed to set attr fission_factor %ld for op:%s(%s)", - fission_factor, - hcom_op->GetName().c_str(), + "Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(), hcom_op->GetType().c_str()); if (!group_name.empty()) { @@ -131,7 +147,7 @@ bool MdsUtils::IsMDSNeeded() { GELOGI("[MDS]device type is %s, skip mds", device_type.c_str()); return false; } - //TODO:解析系统的配置文件得到exe unit + // TODO: Parse the system configuration file to get the exe unit std::string sys_config_exe_unit = "DIE"; return device_type != sys_config_exe_unit; } @@ -153,18 +169,17 @@ Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodeP GeAttrValue::NAMED_ATTRS thread_instance; thread_instance.SetName(std::to_string(device_id)); - (void) thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom(device_id)); - // TODO:跟rts确认属性值 - (void) thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom("MultiMode")); - (void) thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom(compute_graph->GetName())); - (void) thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom(graph_inputs)); + (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom(device_id)); + // TODO: Change to the enumeration from the RTS header file + (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom("MultiMode")); + (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom(compute_graph->GetName())); + (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom(graph_inputs)); deploy_info.emplace_back(thread_instance); GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); } // set deploy info REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), - "Set attr failed for graph %s", - compute_graph->GetName().c_str()); + "Set attr failed for graph %s", compute_graph->GetName().c_str()); return SUCCESS; } @@ -176,7 +191,7 @@ CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) { is_unknown_graph = true; } CutType selected_cut_type = kNoCut; - for (const auto &data: compute_graph->GetInputNodes()) { + for (const auto &data : compute_graph->GetInputNodes()) { GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str()); auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN); auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index); @@ -197,8 +212,7 @@ CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) { return selected_cut_type; } Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, - const std::multimap &deploys, - const std::string &device_type) { + const std::multimap &deploys, const std::string &device_type) { GE_CHECK_NOTNULL(compute_graph); GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
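Review note: this overload of SetDeployInfo consumes a std::multimap keyed by device id, one entry per model instance, each carrying that device's graph inputs; the loop in the next hunk turns every entry into one NAMED_ATTRS record. A small standalone sketch of that grouping (std::string stands in for GeTensorPtr, and the input names are illustrative):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // One entry per device; pair.first is the device id, pair.second its inputs.
  std::multimap<int64_t, std::vector<std::string>> deploys;
  deploys.emplace(0, std::vector<std::string>{"input_slice_dev0"});
  deploys.emplace(1, std::vector<std::string>{"input_slice_dev1"});
  for (const auto &pair : deploys) {
    std::cout << "device " << pair.first << ": " << pair.second.size()
              << " input(s)" << std::endl;  // one deploy record per entry
  }
  return 0;
}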
@@ -208,19 +222,18 @@ Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, int64_t device_id = pair.first; GeAttrValue::NAMED_ATTRS thread_instance; thread_instance.SetName(std::to_string(device_id)); - (void) thread_instance.SetAttr(kAttrNeedReturnResult, - GeAttrValue::CreateFrom(deploy_info.empty() ? true : false)); - (void) thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom(device_id)); - (void) thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom(device_type)); - (void) thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom(compute_graph->GetName())); - (void) thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom(pair.second)); + (void)thread_instance.SetAttr(kAttrNeedReturnResult, + GeAttrValue::CreateFrom(deploy_info.empty() ? true : false)); + (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom(device_id)); + (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom(device_type)); + (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom(compute_graph->GetName())); + (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom(pair.second)); deploy_info.emplace_back(thread_instance); GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); } // set deploy info REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), - "Set attr failed for graph %s", - compute_graph->GetName().c_str()); + "Set attr failed for graph %s", compute_graph->GetName().c_str()); return SUCCESS; } @@ -233,24 +246,20 @@ Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr & auto src_graph = src_node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(src_graph); std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count)); - auto - hcom_allgather_node = AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1); + auto hcom_allgather_node = + AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1); GE_CHECK_NOTNULL(hcom_allgather_node); MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node), - "[DataGather] failed between %s and %s", - src_node->GetName().c_str(), + "[DataGather] failed between %s and %s", src_node->GetName().c_str(), dst_node->GetName().c_str()); MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup), - "[DataGather]set attr for node for %s(%s) failed", - hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str()); + "[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(), + hcom_allgather_node->GetType().c_str()); REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize), - "Failed to set attr reduction type %s for op:%s(%s)", - kDefaultReduction.c_str(), - hcom_allgather_node->GetName().c_str(), - hcom_allgather_node->GetType().c_str()); + "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), + hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str()); MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false), - "[DataGather] %s call infershape failed", - hcom_allgather_node->GetName().c_str()); + "[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str()); data_gather_count++; return SUCCESS; } @@ -269,8 +278,7 @@ Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const
InDataAnchorPtr & NodePtr all_reduce_node = nullptr; if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) { MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node), - "[DataReduce] construct allreduce node for %s failed", - all_reduce_node->GetName().c_str()); + "[DataReduce] construct allreduce node for %s failed", all_reduce_node->GetName().c_str()); GE_CHECK_NOTNULL(all_reduce_node); } else { GE_CHECK_NOTNULL(all_reduce_node); @@ -300,25 +308,19 @@ Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &d GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx()); NodePtr slice_node = nullptr; MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node), - "[DataSlice] construct slice node for %s failed", - src_node->GetName().c_str()); + "[DataSlice] construct slice node for %s failed", src_node->GetName().c_str()); GE_CHECK_NOTNULL(slice_node); - MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), - "[DataSlice] failed between %s and %s", - src_node->GetName().c_str(), - dst_node->GetName().c_str()); - MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), - "[DataSlice] %s call infer shape failed", + MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s", + src_node->GetName().c_str(), dst_node->GetName().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed", slice_node->GetName().c_str()); return SUCCESS; } -Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, - const GeTensorDesc &tensor, - Node *input_node, +Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node, NodePtr &slice_node) { vector slice_sizes = tensor.GetShape().GetDims(); - // TODO: 用图结构表达 + // TODO: Express with graph structure slice_sizes[0] /= kDeployNumber; vector ge_tensors; GeTensorDesc ge_tensor_desc; @@ -346,10 +348,10 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, GE_CHECK_NOTNULL(input_node); MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)), "[ConstructSliceNode] add edge failed"); - MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), - mul_node->GetInDataAnchor(1)), "[ConstructSliceNode] add edge failed"); - MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), - "[DataSlice] %s call infer shape failed", + MDS_REQUIRE_SUCCESS( + GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)), + "[ConstructSliceNode] add edge failed"); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed", mul_node->GetName().c_str()); NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1); bool is_first_input = true; @@ -363,8 +365,7 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, "[ConstructSliceNode] add edge failed"); } } - MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), - "[DataSlice] %s call infer shape failed", + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed", pack_node->GetName().c_str()); slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1); 
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)), @@ -375,9 +376,7 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, return SUCCESS; } -NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, - const string &name, - const string &type, +NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, const GeTensorDesc &tensor) { GELOGI("Begin to create op: %s", name.c_str()); OpDescBuilder op_desc_builder(name, type); @@ -389,20 +388,17 @@ NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, } NodePtr node = graph->AddNode(op_desc); if (node == nullptr) { - REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); - GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), graph->GetName().c_str()); + GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), + graph->GetName().c_str()); return nullptr; } return node; } -NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, - const std::string &type, - const std::string &node_name, - size_t input_num, - size_t output_num) { +NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type, + const std::string &node_name, size_t input_num, size_t output_num) { GELOGI("Begin to create op: %s", node_name.c_str()); OpDescBuilder op_desc_builder(node_name, type); OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build(); @@ -413,10 +409,10 @@ NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, } NodePtr node = graph->AddNode(op_desc); if (node == nullptr) { - REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); - GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), graph->GetName().c_str()); + GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), + graph->GetName().c_str()); return nullptr; } return node; @@ -437,27 +433,20 @@ NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr return graph->AddNodeFront(const_desc); } -Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, - const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst, - NodePtr &reduce_node) { +Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst, NodePtr &reduce_node) { std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count)); reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1); MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node), "[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(), - src->GetOwnerNode()->GetName().c_str(), - 
dst->GetOwnerNode()->GetName().c_str()); + src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup), - "[DataReduce][Create] set attr for allreduce node for %s failed", - reduce_node->GetName().c_str()); + "[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str()); REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction), - "Failed to set attr reduction type %s for op:%s(%s)", - kDefaultReduction.c_str(), - reduce_node->GetName().c_str(), - reduce_node->GetType().c_str()); - MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), - "[DataReduce] %s call infershape failed", + "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), + reduce_node->GetName().c_str(), reduce_node->GetType().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed", reduce_node->GetName().c_str()); auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1); vector slice_sizes{kDeployNumber}; @@ -471,20 +460,17 @@ Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)), "[ConstructSliceNode] add edge failed"); MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node), - "[DataReduce] failed insert %s between %s and %s", - div_node->GetName().c_str(), - reduce_node->GetName().c_str(), - dst->GetOwnerNode()->GetName().c_str()); - MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), - "[DataReduce] %s call infershape failed", + "[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(), + reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed", div_node->GetName().c_str()); return SUCCESS; } bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) { - //TODO: 识别到网络中本来就是多p的模型,即已经有了allreduce节点,则无需插入 + // TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node, + // so there is no need to insert one return true; } -} - +} // namespace ge diff --git a/ge/graph/passes/mds_kernels/mds_utils.h b/ge/graph/passes/mds_kernels/mds_utils.h index 72c2564d..d5d9d430 100644 --- a/ge/graph/passes/mds_kernels/mds_utils.h +++ b/ge/graph/passes/mds_kernels/mds_utils.h @@ -30,13 +30,13 @@ #include "graph/utils/op_desc_utils.h" #include "../pass_utils.h" -#define REQUIRE(cond, ...) \ - do { \ - if (!(cond)) { \ - REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ - GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ - return FAILED; \ - } \ +#define REQUIRE(cond, ...) \ + do { \ + if (!(cond)) { \ + REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ + GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ + return FAILED; \ + } \ } while (0)
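Review note: the REQUIRE macro above is wrapped in do { ... } while (0) so that a multi-statement macro expands to exactly one statement; without it, the caller's trailing semicolon after the braces would break if/else pairing. A compilable illustration of the idiom (the macro name here is illustrative):

#include <cstdio>

// Single-statement expansion: safe under an unbraced if, and the caller's ';'
// completes the do/while statement.
#define FAIL_RETURN()  \
  do {                 \
    std::puts("fail"); \
    return 1;          \
  } while (0)

int Check(bool ok) {
  if (!ok)
    FAIL_RETURN();  // with a plain { ... } macro, the ';' here would orphan the else below
  else
    std::puts("ok");
  return 0;
}

int main() { return Check(true); }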
#define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__) @@ -44,7 +44,7 @@ #define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) namespace ge { namespace { -//Invalid location index +// Invalid location index const int64_t kInvalidIndex = -1; enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 }; enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 }; @@ -69,10 +69,9 @@ const std::string kDefaultReduction = "sum"; const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE"; const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE"; -//deploy info +// deploy info const char *const kAttrNeedReturnResult = "_need_return_result"; const char *const kAttrDeviceType = "_device_type"; -// TODO:跟rts确认属性值 const char *const kDieDeviceTypeValue = "MultiMode"; const char *const kAttrDeviceId = "_device_id"; const char *const kAttrGraphName = "_graph_name"; @@ -80,7 +79,7 @@ const char *const kAttrGraphInputs = "_graph_inputs"; using GraphInputs = vector; using DeviceId = int64_t; using GraphInputNodes = vector; -} +} // namespace class MdsUtils { public: // Parse the configuration file and determine whether to enable MDS based on the value of device_type. @@ -90,24 +89,25 @@ class MdsUtils { static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type); static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type); - static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, - int64_t fission_factor, + static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name = ""); - /// @param [in] index 切分的轴 - /// @param [in] deploy_number 切分的份数 + /// @param [in] index the axis to cut along + /// @param [in] deploy_number the number of pieces to cut into - static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, - CutType type, + static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number = kDeployNumber); // Sets the information, notifies the number of threads to be started during the // loading phase, the device on which each thread should run, and constructs different input data on each device.
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node); - static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, - const std::multimap &deploys, + static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap &deploys, const std::string &device_type = kDieDeviceTypeValue); // Get cut policy for whole graph static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph); - static GraphInputNodes GetInputNodes() { return input_nodes_; } - static void AddInputNode(const NodePtr &input_node) { input_nodes_.push_back(input_node); } + static GraphInputNodes GetInputNodes() { + return input_nodes_; + } + static void AddInputNode(const NodePtr &input_node) { + input_nodes_.push_back(input_node); + } static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); @@ -115,25 +115,16 @@ class MdsUtils { private: static GraphInputNodes input_nodes_; - static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, - const string &type, - const string &node_name, - size_t input_num, - size_t output_num); - static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, - const string &name, - const string &type, + static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name, + size_t input_num, size_t output_num); + static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, const GeTensorDesc &tensor = GeTensorDesc()); - static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, - const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst, - NodePtr &reduce_node); - static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, - const GeTensorDesc &tensor, - Node *node, + static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst, NodePtr &reduce_node); + static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node, NodePtr &slice_node); static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node); static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph); }; -} -#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ +} // namespace ge +#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ diff --git a/ge/graph/passes/mds_pass.cc b/ge/graph/passes/mds_pass.cc index 89afcb88..aa411ed5 100644 --- a/ge/graph/passes/mds_pass.cc +++ b/ge/graph/passes/mds_pass.cc @@ -43,30 +43,30 @@ Status ModelDeploySchedulerPass::CutProcess() { } auto type = MdsUtils::TryGetGraphCutType(compute_graph_); switch (type) { - case kCutN:MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), - "[MDS][CutNProcessImply] failed, graph_name:[%s]", - GetGraphName()); + case kCutN: + MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]", + GetGraphName()); break; - case kCutH:MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), - "[MDS][CutHProcessImply] failed, graph_name:[%s]", - GetGraphName()); + case kCutH: + MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]", + GetGraphName()); break; - case kDynamicCutN:MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), - "[MDS][CutNProcessImply] failed, graph_name:[%s]", - 
GetGraphName()); + case kDynamicCutN: + MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", + GetGraphName()); break; - case kDynamicCutH:MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), - "[MDS][CutHProcessImply] failed, graph_name:[%s]", - GetGraphName()); + case kDynamicCutH: + MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", + GetGraphName()); break; - case kDynamicCutAll:MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), - "[MDS][DynamicCutAll] failed, graph_name:[%s]", - GetGraphName()); + case kDynamicCutAll: + MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]", + GetGraphName()); break; - default:GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName()); + default: + GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName()); return SUCCESS; } - + return SUCCESS; } Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) { @@ -78,21 +78,19 @@ op_kernel = DeploySchedulerKernel::Instance(); } if (is_dynamic) { - MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), - "[MDS][DYNAMIC_CUTN] failed, node:[%s]", + MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]", node->GetName().c_str()); } else { MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str()); } bool is_grad_compute_node = false; - if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) - && is_grad_compute_node) { + if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && + is_grad_compute_node) { grad_compute_nodes_.push_back(node); } } - //TODO:针对单输出多引用插入的allgather,allreduce节点做广度融合优化 - MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), - "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]", + // TODO: For allgather/allreduce nodes inserted for a single output with multiple references, + // do breadth fusion optimization + MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]", compute_graph->GetName().c_str()); return SUCCESS; } @@ -105,12 +103,10 @@ Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute op_kernel = DeploySchedulerKernel::Instance(); } if (is_dynamic) { - MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), - "[MDS][DYNAMIC_CUTH] failed, node:[%s]", + MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]", node->GetName().c_str()); } else { MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str()); - } } return SUCCESS; @@ -121,13 +117,11 @@ Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_gr std::vector output_nodes; auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); - MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), - "[MDS][CutNProcessImply] failed, graph_name:[%s]", + MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", compute_graph0->GetName().c_str()); - MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), - "[MDS][CutHProcessImply]
failed, graph_name:[%s]", + MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", compute_graph1->GetName().c_str()); -//TODO:创建case节点,把两个图放在case的两个分支下,case节点添加到原来的compute_graph中,构造case节点的输入 + // TODO:Create a case node, put the two graphs under the two branches of case return SUCCESS; } @@ -143,9 +137,8 @@ Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) { Status ModelDeploySchedulerPass::SetDeployInfo() { vector deployInfo; - REQUIRE (!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo), - "%s already has deployed before!", - GetGraphName()); + REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo), + "%s already has deployed before!", GetGraphName()); std::multimap deploys; for (int64_t j = 0; j < kDeployNumber; j++) { int64_t device_id = j; @@ -179,7 +172,6 @@ Status ModelDeploySchedulerPass::SMDPWeight() { return SUCCESS; } Status ModelDeploySchedulerPass::SMDPGradient() { - //TDOD:标识buffer poolid return SUCCESS; } -} +} // namespace ge diff --git a/ge/graph/passes/mds_pass.h b/ge/graph/passes/mds_pass.h index dfe4c0f6..4a18cc1e 100644 --- a/ge/graph/passes/mds_pass.h +++ b/ge/graph/passes/mds_pass.h @@ -28,6 +28,7 @@ namespace ge { class ModelDeploySchedulerPass : public GraphPass { public: Status Run(ge::ComputeGraphPtr graph) override; + private: // Part0:Process Func // cut and dynamic cut @@ -47,22 +48,24 @@ class ModelDeploySchedulerPass : public GraphPass { // set delpoyinfo Status SetDeployInfo(); - //Part1: Utils Func -// std::vector GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index); -// std::vector GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index); + // Part1: Utils Func + // std::vector GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index); + // std::vector GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index); Status HcomNodeFusionProcess(); Status GetAllModelStateVar(); Status GetAllWeightVar(); - std::vector GetAllGradComputeNodes() { return grad_compute_nodes_; } + std::vector GetAllGradComputeNodes() { + return grad_compute_nodes_; + } const char *GetGraphName() const { return compute_graph_->GetName().c_str(); } - //members + // members std::vector model_state_vars_; std::vector model_weight_vars_; std::vector grad_compute_nodes_; ComputeGraphPtr compute_graph_ = nullptr; }; } // namespace ge -#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ +#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_