@@ -166,6 +166,11 @@ set(EXECUTOR_SRC_LIST | |||
"graph/manager/util/debug.cc" | |||
#"graph/manager/util/hcom_util.cc" # Just for runner. | |||
"graph/passes/pass_utils.cc" | |||
"graph/passes/mds_pass.cc" | |||
"graph/passes/mds_kernels/mds_utils.cc" | |||
"graph/passes/mds_kernels/variable_mds_kernel.cc" | |||
"graph/passes/mds_kernels/conv2d_mds_kernel.cc" | |||
"graph/passes/mds_kernels/base_mds_kernel.cc" | |||
"host_kernels/add_kernel.cc" | |||
"host_kernels/broadcast_args_kernel.cc" | |||
"host_kernels/broadcast_gradient_args_kernel.cc" | |||
@@ -77,9 +77,11 @@ Status GeLocalOpsKernelInfoStore::DestroySession(const map<string, string> &sess | |||
return SUCCESS; | |||
} | |||
Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) { | |||
//TODO: 1. Check whether the variable type is marked as a trainable variable
//2. Check whether smdp stage 1 or 3 is enabled
//If both hold, mark the current variable node as cuttable in the variable cut info
// TODO: | |||
// 1. Whether the variable type is identified as a trainable variable | |||
// 2. Whether smdp stage 1 or 3 is enabled
// When both conditions hold, mark the current variable
// node as cuttable in the variable cut info
return SUCCESS; | |||
} | |||
} // namespace ge_local | |||
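A possible shape for this TODO, left here as a review sketch: it assumes trainability is readable via a boolean attribute and records cut support through the ATTR_NAME_CUT_INFO list attribute that mds_utils.cc later reads. The "_trainable" attribute name and the smdp-stage helper are hypothetical placeholders, not confirmed GE APIs.

```cpp
// Sketch only. "_trainable" and IsSmdpStage1Or3Enabled() are assumed names;
// ATTR_NAME_CUT_INFO matches the attribute read by
// MdsUtils::IsDistributedDeploySupported.
Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  bool trainable = false;
  // 1. Only trainable variables are candidates for cutting.
  if (!ge::AttrUtils::GetBool(node->GetOpDesc(), "_trainable", trainable) || !trainable) {
    return SUCCESS;
  }
  // 2. Only when smdp stage 1 or 3 is enabled (hypothetical helper).
  if (!IsSmdpStage1Or3Enabled()) {
    return SUCCESS;
  }
  // Mark every axis of the variable output as split-cut supported.
  auto tensor_desc = node->GetOpDesc()->MutableOutputDesc(0);
  GE_CHECK_NOTNULL(tensor_desc);
  std::vector<int64_t> cut_info(tensor_desc->GetShape().GetDimNum(), 1);  // 1: split-cut bit (assumed)
  (void)ge::AttrUtils::SetListInt(*tensor_desc, ATTR_NAME_CUT_INFO, cut_info);
  return SUCCESS;
}
```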
@@ -540,14 +540,12 @@ Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_ro | |||
return SUCCESS; | |||
} | |||
Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &inputs, | |||
const RunAsyncCallback &callback) { | |||
const RunAsyncCallback &callback) { | |||
// get the number of deployed instances of the model
auto root_graph = ge_root_model->GetRootGraph(); | |||
vector<GeAttrValue::NAMED_ATTRS> deploy_info; | |||
if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { | |||
GELOGE(FAILED, | |||
"[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", | |||
root_graph->GetName().c_str(), | |||
GELOGE(FAILED, "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), | |||
ATTR_NAME_DEPLOY_INFO.c_str()); | |||
return FAILED; | |||
} | |||
@@ -567,7 +565,7 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model | |||
std::vector<GeTensorPtr> graph_inputs; | |||
if (ge::AttrUtils::MutableListTensor(thread_instance, kAttrGraphInputs, graph_inputs)) { | |||
std::vector<ge::Tensor> graph_input_updated(inputs.begin(), inputs.end()); | |||
for (const auto &ge_tensor_ptr:graph_inputs) { | |||
for (const auto &ge_tensor_ptr : graph_inputs) { | |||
graph_input_updated.push_back(TensorAdapter::AsTensor(*ge_tensor_ptr)); | |||
} | |||
GraphExecutor graph_executor; | |||
@@ -575,12 +573,12 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model | |||
std::future<Status> f; | |||
bool need_return_result = false; | |||
if ((ge::AttrUtils::GetBool(thread_instance, kAttrNeedReturnResult, need_return_result) && need_return_result)) { | |||
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, | |||
model_ids[i], graph_input_updated, callback); | |||
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, | |||
callback); | |||
} else { | |||
RunAsyncCallback callback_stub; | |||
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, | |||
model_ids[i], graph_input_updated, callback_stub); | |||
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, | |||
callback_stub); | |||
} | |||
if (!f.valid()) { | |||
GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); | |||
@@ -600,8 +598,8 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model | |||
return SUCCESS; | |||
} | |||
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, const std::vector<ge::Tensor> &inputs, | |||
const RunAsyncCallback &callback) { | |||
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, | |||
const std::vector<ge::Tensor> &inputs, const RunAsyncCallback &callback) { | |||
if (model_id == kInvalidModelId) { | |||
GELOGE(INTERNAL_ERROR, "No valid model id."); | |||
return INTERNAL_ERROR; | |||
@@ -149,17 +149,14 @@ Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel> | |||
auto root_graph = ge_root_model->GetRootGraph(); | |||
vector<GeAttrValue::NAMED_ATTRS> deploy_info; | |||
if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { | |||
GELOGE(FAILED, | |||
"[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s", | |||
root_graph->GetName().c_str(), | |||
ATTR_NAME_DEPLOY_INFO.c_str()); | |||
GELOGE(FAILED, "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s", | |||
root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str()); | |||
return FAILED; | |||
} | |||
auto thread_instances_size = deploy_info.size(); | |||
auto device_id_fission_from = GetContext().DeviceId(); | |||
GELOGI("Graph %s need to load model %zu times, and fission from device %u.", | |||
root_graph->GetName().c_str(), thread_instances_size, | |||
device_id_fission_from); | |||
GELOGI("Graph %s need to load model %zu times, and fission from device %u.", root_graph->GetName().c_str(), | |||
thread_instances_size, device_id_fission_from); | |||
ThreadPool executor(thread_instances_size); | |||
std::vector<std::future<Status>> vector_future; | |||
GE_TIMESTAMP_START(LoadModelOnline); | |||
@@ -167,17 +164,16 @@ Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel> | |||
auto thread_instance = deploy_info[i]; | |||
std::string device_type; | |||
ModelIdInfo model_id_info; | |||
//TODO: the listener needs to distinguish sync from async
std::shared_ptr<ModelListener> listener; | |||
if (is_async) { | |||
listener = MakeShared<RunAsyncListener>(); | |||
GE_CHECK_NOTNULL(listener); | |||
} else { | |||
// TODO: GraphModelListener for sync | |||
} | |||
int64_t device_id_fissioned = kInvalidDieId; | |||
if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) | |||
|| device_id_fissioned == kInvalidDieId) { | |||
if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) || | |||
device_id_fissioned == kInvalidDieId) { | |||
REPORT_CALL_ERROR("E19999", "graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), | |||
ATTR_NAME_DEPLOY_INFO.c_str()); | |||
GELOGE(GRAPH_FAILED, "[LoadMultiModelOnline] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), | |||
@@ -26,8 +26,7 @@ shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node) { | |||
string type = node->GetType(); | |||
if (type == FRAMEWORKOP) { | |||
if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { | |||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", | |||
ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), | |||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), | |||
node->GetName().c_str(), node->GetType().c_str()); | |||
return nullptr; | |||
} | |||
@@ -35,10 +34,10 @@ shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node) { | |||
return factory.Create(type); | |||
} | |||
} | |||
} // namespace mds_cut_pass | |||
shared_ptr<DeploySchedulerKernel> DeploySchedulerKernel::Instance() { | |||
static const std::shared_ptr<DeploySchedulerKernel> instance_ptr = | |||
shared_ptr<DeploySchedulerKernel>(new(std::nothrow) DeploySchedulerKernel()); | |||
shared_ptr<DeploySchedulerKernel>(new (std::nothrow) DeploySchedulerKernel()); | |||
return instance_ptr; | |||
} | |||
@@ -63,20 +62,15 @@ Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) { | |||
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { | |||
tensor_desc->SetShape(src_tensor_desc->GetShape()); | |||
} else { | |||
MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor), | |||
"[CutN] failed to gather between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
MDS_REQUIRE_SUCCESS( | |||
MdsUtils::DataGather(src_anchor, in_anchor), "[CutN] failed to gather between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); | |||
} | |||
} else { | |||
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { | |||
MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), | |||
"[CutN] failed to slice between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
} else { | |||
tensor_desc->SetShape(src_tensor_desc->GetShape()); | |||
@@ -85,14 +79,11 @@ Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) { | |||
// insert hcomallreduce for cutn | |||
bool is_grad_compute_node = false; | |||
if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) | |||
&& is_grad_compute_node) { | |||
MDS_REQUIRE_SUCCESS(MdsUtils::DataReduce(src_anchor, in_anchor), | |||
"[CutN] failed to reduce between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && | |||
is_grad_compute_node) { | |||
MDS_REQUIRE_SUCCESS( | |||
MdsUtils::DataReduce(src_anchor, in_anchor), "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); | |||
} | |||
} | |||
@@ -122,32 +113,23 @@ Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) { | |||
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { | |||
MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()), | |||
"[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
} else { | |||
MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor), | |||
"[CutH] failed to gather between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
MDS_REQUIRE_SUCCESS( | |||
MdsUtils::DataGather(src_anchor, in_anchor), "[CutH] failed to gather between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); | |||
} | |||
} else { | |||
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { | |||
MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), | |||
"[CutH] failed to slice between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
} else { | |||
MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true), | |||
"[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", | |||
src_op_desc->GetName().c_str(), | |||
src_anchor->GetIdx(), | |||
op_desc->GetName().c_str(), | |||
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||
in_anchor->GetIdx()); | |||
} | |||
} | |||
@@ -157,4 +139,4 @@ Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) { | |||
MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] %s call infer shape failed", node->GetName().c_str());
return SUCCESS; | |||
} | |||
} | |||
} // namespace ge |
@@ -55,12 +55,16 @@ class DeploySchedulerKernel { | |||
// halo exchange process | |||
Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false); | |||
NodePtr GetInputNode() { return input_node_; } | |||
NodePtr GetInputNode() { | |||
return input_node_; | |||
} | |||
DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete; | |||
DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete; | |||
protected: | |||
DeploySchedulerKernel() = default; | |||
virtual ~DeploySchedulerKernel() = default; | |||
private: | |||
NodePtr input_node_ = nullptr; | |||
}; | |||
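Usage note: callers go through the singleton above, first asking the factory for a type-specific kernel and falling back to the base kernel, as mds_pass.cc does. A minimal sketch:

```cpp
shared_ptr<DeploySchedulerKernel> op_kernel = mds_cut_pass::GetKernelByType(node);
if (op_kernel == nullptr) {
  op_kernel = DeploySchedulerKernel::Instance();  // base kernel as the fallback
}
MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str());
```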
@@ -69,4 +73,4 @@ namespace mds_cut_pass { | |||
shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node); | |||
} | |||
} // namespace ge | |||
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ | |||
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ |
@@ -21,39 +21,48 @@ thread_local int64_t data_slice_count = 0; | |||
thread_local int64_t data_gather_count = 0; | |||
thread_local int64_t data_reduce_count = 0; | |||
const std::string kPrefix = "mds"; | |||
} | |||
} // namespace | |||
int64_t MdsUtils::GetNLocation(Format fmt) { | |||
int64_t loc = kNInvalidLocation; | |||
switch (fmt) { | |||
case FORMAT_NCHW : | |||
case FORMAT_NHWC :loc = kNLocation0; | |||
case FORMAT_NCHW: | |||
case FORMAT_NHWC: | |||
loc = kNLocation0; | |||
break; | |||
case FORMAT_CHWN: | |||
case FORMAT_HWCN:loc = kNLocation3; | |||
case FORMAT_HWCN: | |||
loc = kNLocation3; | |||
break; | |||
default:GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); | |||
default: | |||
GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); | |||
} | |||
return loc; | |||
} | |||
int64_t MdsUtils::GetHLocation(Format fmt) { | |||
int64_t loc = kHInvalidLocation; | |||
switch (fmt) { | |||
case FORMAT_HWCN :loc = kHLocation0; | |||
case FORMAT_HWCN: | |||
loc = kHLocation0; | |||
break; | |||
case FORMAT_NHWC : | |||
case FORMAT_CHWN :loc = kHLocation1; | |||
case FORMAT_NHWC: | |||
case FORMAT_CHWN: | |||
loc = kHLocation1; | |||
break; | |||
case FORMAT_NCHW :loc = kHLocation2; | |||
default :GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); | |||
case FORMAT_NCHW:
loc = kHLocation2;
break;
default:
GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
} | |||
return loc; | |||
} | |||
int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) { | |||
Format fmt = ge_tensor_desc->GetFormat(); | |||
switch (type) { | |||
case kCutN :return GetNLocation(fmt); | |||
case kCutH :return GetHLocation(fmt); | |||
default :; | |||
case kCutN: | |||
return GetNLocation(fmt); | |||
case kCutH: | |||
return GetHLocation(fmt); | |||
default:; | |||
} | |||
GELOGE(FAILED, "[MDS]invalid CutType:%d", type); | |||
return kInvalidIndex; | |||
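To make the mapping concrete: for FORMAT_NCHW, N is dim 0 and H is dim 2, so GetIndexByFormat returns 0 for kCutN and 2 for kCutH. A minimal usage sketch (the shape values are illustrative):

```cpp
// An NCHW tensor of shape [32, 3, 224, 224]:
GeTensorDescPtr desc =
    std::make_shared<GeTensorDesc>(GeShape({32, 3, 224, 224}), FORMAT_NCHW, DT_FLOAT);
int64_t n_axis = MdsUtils::GetIndexByFormat(desc, kCutN);  // 0: the batch dimension
int64_t h_axis = MdsUtils::GetIndexByFormat(desc, kCutH);  // 2: the height dimension
```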
@@ -64,7 +73,7 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des | |||
GELOGE(FAILED, "[MDS]invalid input param: tensor is null!"); | |||
return false; | |||
} | |||
if (type != kCutN || type != kCutH) { | |||
if (type != kCutN && type != kCutH) { | |||
REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type); | |||
GELOGE(FAILED, "[MDS]invalid CutType:%d", type); | |||
return false; | |||
@@ -75,6 +84,18 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des | |||
GELOGE(FAILED, "[MDS]", "invalid index param:%ld", cut_index); | |||
return false; | |||
} | |||
auto dims = ge_tensor_desc->GetShape().GetDims(); | |||
if (cut_index < 0 || static_cast<size_t>(cut_index) >= dims.size()) {
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
dims.size());
GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
dims.size());
return false; | |||
} | |||
if (dims[cut_index] % kDeployNumber != 0) { | |||
GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]); | |||
return false; | |||
} | |||
vector<int64_t> cut_support_info; | |||
if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) { | |||
REPORT_INNER_ERROR("E19999", "call GetlistInt failed"); | |||
@@ -82,10 +103,10 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des | |||
return false; | |||
} | |||
if (cut_index < 0 || static_cast<size_t>(cut_index) >= cut_support_info.size()) {
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", | |||
cut_index, type, cut_support_info.size()); | |||
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", | |||
cut_index, type, cut_support_info.size()); | |||
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, | |||
type, cut_support_info.size()); | |||
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, | |||
type, cut_support_info.size()); | |||
return false; | |||
} | |||
if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) { | |||
@@ -95,8 +116,7 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des | |||
} | |||
return cut_support_info[cut_index] & kSplitCutSupported; | |||
} | |||
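The trailing check reads each ATTR_NAME_CUT_INFO entry as a small bitmask bounded by kNotSupport and kAnyCutSupported. The concrete values below are assumptions for illustration only, not taken from the header:

```cpp
// Assumed encoding: bit 0 = split-cut supported, bit 1 = another cut kind.
enum CutSupport : int64_t {
  kNotSupport        = 0,
  kSplitCutSupported = 1 << 0,
  kAnyCutSupported   = 0x3,  // all known support bits set
};
// With cut_support_info = {1, 0, 1, 0} on an NCHW tensor, only dim 0 (N) and
// dim 2 (H) satisfy cut_support_info[cut_index] & kSplitCutSupported.
```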
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, | |||
CutType type, int64_t deploy_number) { | |||
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) { | |||
GE_CHECK_NOTNULL(ge_tensor_desc); | |||
auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type); | |||
auto dims = ge_tensor_desc->GetShape().GetDims(); | |||
@@ -109,14 +129,10 @@ Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, | |||
} | |||
Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) { | |||
GE_CHECK_NOTNULL(hcom_op); | |||
REQUIRE(fission_factor > kDefaultFissionFactor, | |||
"fission_factor %ld need be bigger than %ld", | |||
fission_factor, | |||
REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld need be bigger than %ld", fission_factor, | |||
kDefaultFissionFactor); | |||
REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor), | |||
"Failed to set attr fission_factor %ld for op:%s(%s)", | |||
fission_factor, | |||
hcom_op->GetName().c_str(), | |||
"Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(), | |||
hcom_op->GetType().c_str()); | |||
if (!group_name.empty()) { | |||
@@ -131,7 +147,7 @@ bool MdsUtils::IsMDSNeeded() { | |||
GELOGI("[MDS]device type is %s, skip mds", device_type.c_str()); | |||
return false; | |||
} | |||
//TODO: Parse the system configuration file to get the exe unit
// TODO: Parse the system configuration file to get the execution unit (sys_config_exe_unit)
std::string sys_config_exe_unit = "DIE"; | |||
return device_type != sys_config_exe_unit; | |||
} | |||
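A sketch for the TODO above; the environment-variable mechanism and its name are assumptions, not the confirmed GE configuration path:

```cpp
#include <cstdlib>
#include <string>

// Hypothetical lookup of the execution unit; defaults to the current
// hard-coded "DIE" when nothing is configured.
static std::string GetExecUnitFromConfig() {
  const char *unit = std::getenv("GE_EXEC_UNIT");  // assumed variable name
  return (unit != nullptr) ? std::string(unit) : "DIE";
}
```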
@@ -153,18 +169,17 @@ Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodeP | |||
GeAttrValue::NAMED_ATTRS thread_instance; | |||
thread_instance.SetName(std::to_string(device_id)); | |||
(void) thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id)); | |||
// TODO: Confirm the attribute value with RTS
(void) thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode")); | |||
(void) thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName())); | |||
(void) thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs)); | |||
(void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id)); | |||
// TODO:Change to enumeration from RTS header file | |||
(void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode")); | |||
(void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName())); | |||
(void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs)); | |||
deploy_info.emplace_back(thread_instance); | |||
GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); | |||
} | |||
// set deploy info | |||
REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), | |||
"Set attr failed for graph %s", | |||
compute_graph->GetName().c_str()); | |||
"Set attr failed for graph %s", compute_graph->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -176,7 +191,7 @@ CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) { | |||
is_unknown_graph = true; | |||
} | |||
CutType selected_cut_type = kNoCut; | |||
for (const auto &data: compute_graph->GetInputNodes()) { | |||
for (const auto &data : compute_graph->GetInputNodes()) { | |||
GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str()); | |||
auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN); | |||
auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index); | |||
@@ -197,8 +212,7 @@ CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) { | |||
return selected_cut_type; | |||
} | |||
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, | |||
const std::multimap<DeviceId, GraphInputs> &deploys, | |||
const std::string &device_type) { | |||
const std::multimap<DeviceId, GraphInputs> &deploys, const std::string &device_type) { | |||
GE_CHECK_NOTNULL(compute_graph); | |||
GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str()); | |||
@@ -208,19 +222,18 @@ Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, | |||
int64_t device_id = pair.first; | |||
GeAttrValue::NAMED_ATTRS thread_instance; | |||
thread_instance.SetName(std::to_string(device_id)); | |||
(void) thread_instance.SetAttr(kAttrNeedReturnResult, | |||
GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty() ? true : false)); | |||
(void) thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id)); | |||
(void) thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type)); | |||
(void) thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName())); | |||
(void) thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second)); | |||
(void)thread_instance.SetAttr(kAttrNeedReturnResult, | |||
GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty()));
(void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id)); | |||
(void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type)); | |||
(void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName())); | |||
(void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second)); | |||
deploy_info.emplace_back(thread_instance); | |||
GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); | |||
} | |||
// set deploy info | |||
REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), | |||
"Set attr failed for graph %s", | |||
compute_graph->GetName().c_str()); | |||
"Set attr failed for graph %s", compute_graph->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -233,24 +246,20 @@ Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr & | |||
auto src_graph = src_node->GetOwnerComputeGraph(); | |||
GE_CHECK_NOTNULL(src_graph); | |||
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count)); | |||
auto | |||
hcom_allgather_node = AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1); | |||
auto hcom_allgather_node = | |||
AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1); | |||
GE_CHECK_NOTNULL(hcom_allgather_node); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node), | |||
"[DataGather] failed between %s and %s", | |||
src_node->GetName().c_str(), | |||
"[DataGather] failed between %s and %s", src_node->GetName().c_str(), | |||
dst_node->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup), | |||
"[DataGather]set attr for node for %s(%s) failed", | |||
hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str()); | |||
"[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(), | |||
hcom_allgather_node->GetType().c_str()); | |||
REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize), | |||
"Failed to set attr reduction type %s for op:%s(%s)", | |||
kDefaultReduction.c_str(), | |||
hcom_allgather_node->GetName().c_str(), | |||
hcom_allgather_node->GetType().c_str()); | |||
"Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), | |||
hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false), | |||
"[DataGather] %s call infershape failed", | |||
hcom_allgather_node->GetName().c_str()); | |||
"[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str()); | |||
data_gather_count++; | |||
return SUCCESS; | |||
} | |||
@@ -269,8 +278,7 @@ Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr & | |||
NodePtr all_reduce_node = nullptr; | |||
if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) { | |||
MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node), | |||
"[DataReduce] construct allreduce node for %s failed", | |||
all_reduce_node->GetName().c_str()); | |||
"[DataReduce] construct allreduce node for %s failed", all_reduce_node->GetName().c_str()); | |||
GE_CHECK_NOTNULL(all_reduce_node); | |||
} else { | |||
GE_CHECK_NOTNULL(all_reduce_node); | |||
@@ -300,25 +308,19 @@ Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &d | |||
GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx()); | |||
NodePtr slice_node = nullptr; | |||
MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node), | |||
"[DataSlice] construct slice node for %s failed", | |||
src_node->GetName().c_str()); | |||
"[DataSlice] construct slice node for %s failed", src_node->GetName().c_str()); | |||
GE_CHECK_NOTNULL(slice_node); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), | |||
"[DataSlice] failed between %s and %s", | |||
src_node->GetName().c_str(), | |||
dst_node->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), | |||
"[DataSlice] %s call infer shape failed", | |||
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s", | |||
src_node->GetName().c_str(), dst_node->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed", | |||
slice_node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, | |||
const GeTensorDesc &tensor, | |||
Node *input_node, | |||
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node, | |||
NodePtr &slice_node) { | |||
vector<int64_t> slice_sizes = tensor.GetShape().GetDims(); | |||
// TODO: Express this with a graph structure
// TODO: Express with graph structure | |||
slice_sizes[0] /= kDeployNumber; | |||
vector<GeTensorPtr> ge_tensors; | |||
GeTensorDesc ge_tensor_desc; | |||
@@ -346,10 +348,10 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, | |||
GE_CHECK_NOTNULL(input_node); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)), | |||
"[ConstructSliceNode] add edge failed"); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), | |||
mul_node->GetInDataAnchor(1)), "[ConstructSliceNode] add edge failed"); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), | |||
"[DataSlice] %s call infer shape failed", | |||
MDS_REQUIRE_SUCCESS( | |||
GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)), | |||
"[ConstructSliceNode] add edge failed"); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed", | |||
mul_node->GetName().c_str()); | |||
NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1); | |||
bool is_first_input = true; | |||
@@ -363,8 +365,7 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, | |||
"[ConstructSliceNode] add edge failed"); | |||
} | |||
} | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), | |||
"[DataSlice] %s call infer shape failed", | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed", | |||
pack_node->GetName().c_str()); | |||
slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)), | |||
@@ -375,9 +376,7 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, | |||
return SUCCESS; | |||
} | |||
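For readers following the subgraph above, this is the arithmetic it encodes, assuming the constant fed into the Mul holds the per-die batch size (dims[0] / kDeployNumber) and input_node yields the device id at runtime. A sketch of the math, not of the graph API:

```cpp
#include <cstdint>
#include <vector>

// Offsets and sizes that the Mul/Pack/Slice subgraph computes for one device.
std::vector<int64_t> SliceOffsets(const std::vector<int64_t> &dims, int64_t device_id,
                                  int64_t deploy_number) {
  std::vector<int64_t> offsets(dims.size(), 0);
  offsets[0] = device_id * (dims[0] / deploy_number);  // Mul: device_id * per-die batch
  return offsets;                                      // Pack: assembles the offset vector
}
// The slice sizes are dims with dims[0] /= kDeployNumber, as set above.
```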
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, | |||
const string &name, | |||
const string &type, | |||
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, | |||
const GeTensorDesc &tensor) { | |||
GELOGI("Begin to create op: %s", name.c_str()); | |||
OpDescBuilder op_desc_builder(name, type); | |||
@@ -389,20 +388,17 @@ NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, | |||
} | |||
NodePtr node = graph->AddNode(op_desc); | |||
if (node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), | |||
op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
graph->GetName().c_str()); | |||
return nullptr; | |||
} | |||
return node; | |||
} | |||
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, | |||
const std::string &type, | |||
const std::string &node_name, | |||
size_t input_num, | |||
size_t output_num) { | |||
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type, | |||
const std::string &node_name, size_t input_num, size_t output_num) { | |||
GELOGI("Begin to create op: %s", node_name.c_str()); | |||
OpDescBuilder op_desc_builder(node_name, type); | |||
OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build(); | |||
@@ -413,10 +409,10 @@ NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, | |||
} | |||
NodePtr node = graph->AddNode(op_desc); | |||
if (node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), | |||
op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
graph->GetName().c_str()); | |||
return nullptr; | |||
} | |||
return node; | |||
@@ -437,27 +433,20 @@ NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr | |||
return graph->AddNodeFront(const_desc); | |||
} | |||
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, | |||
const OutDataAnchorPtr &src, | |||
const InDataAnchorPtr &dst, | |||
NodePtr &reduce_node) { | |||
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, | |||
const InDataAnchorPtr &dst, NodePtr &reduce_node) { | |||
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count)); | |||
reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node), | |||
"[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(), | |||
src->GetOwnerNode()->GetName().c_str(), | |||
dst->GetOwnerNode()->GetName().c_str()); | |||
src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup), | |||
"[DataReduce][Create] set attr for allreduce node for %s failed", | |||
reduce_node->GetName().c_str()); | |||
"[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str()); | |||
REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction), | |||
"Failed to set attr reduction type %s for op:%s(%s)", | |||
kDefaultReduction.c_str(), | |||
reduce_node->GetName().c_str(), | |||
reduce_node->GetType().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), | |||
"[DataReduce] %s call infershape failed", | |||
"Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), | |||
reduce_node->GetName().c_str(), reduce_node->GetType().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed", | |||
reduce_node->GetName().c_str()); | |||
auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1); | |||
vector<int64_t> slice_sizes{kDeployNumber}; | |||
@@ -471,20 +460,17 @@ Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, | |||
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)), | |||
"[ConstructSliceNode] add edge failed"); | |||
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node), | |||
"[DataReduce] failed insert %s between %s and %s", | |||
div_node->GetName().c_str(), | |||
reduce_node->GetName().c_str(), | |||
dst->GetOwnerNode()->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), | |||
"[DataReduce] %s call infershape failed", | |||
"[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(), | |||
reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed", | |||
div_node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) { | |||
//TODO: If the graph is already a multi-p model, i.e. an allreduce node already exists, no insertion is needed
// TODO: recognize that the graph is originally a multi-p model, i.e. there is already an allreduce node,
// in which case there is no need to insert one
return true; | |||
} | |||
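One way this TODO could be implemented, assuming an already-multi-p graph is detectable by an HCOMALLREDUCE consumer of src_node; this stays consistent with the caller in DataReduce, which reuses allreduce_node when the function returns false:

```cpp
bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) {
  // If an allreduce already consumes this output, the graph is already
  // multi-p: hand the existing node back and skip insertion.
  for (const auto &out_node : src_node->GetOutDataNodes()) {
    if (out_node->GetType() == HCOMALLREDUCE) {
      allreduce_node = out_node;
      return false;
    }
  }
  return true;
}
```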
} | |||
} // namespace ge |
@@ -30,13 +30,13 @@ | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "../pass_utils.h" | |||
#define REQUIRE(cond, ...) \ | |||
do { \ | |||
if (!(cond)) { \ | |||
REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ | |||
GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ | |||
return FAILED; \ | |||
} \ | |||
#define REQUIRE(cond, ...) \ | |||
do { \ | |||
if (!(cond)) { \ | |||
REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ | |||
GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ | |||
return FAILED; \ | |||
} \ | |||
} while (0) | |||
#define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__) | |||
@@ -44,7 +44,7 @@ | |||
#define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | |||
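For reviewers skimming the macros: REQUIRE reports the inner error, logs with the [MDS] prefix, and returns FAILED from the enclosing function, so the wrappers turn call-site checks into one-liners. A minimal usage sketch (DoWork is a hypothetical callee):

```cpp
Status Example(const OpDescPtr &op) {
  MDS_REQUIRE_NOT_NULL(op, "op must not be null");
  MDS_REQUIRE_SUCCESS(DoWork(op), "DoWork failed for %s", op->GetName().c_str());
  return SUCCESS;
}
```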
namespace ge { | |||
namespace { | |||
//Invalid location index | |||
// Invalid location index | |||
const int64_t kInvalidIndex = -1; | |||
enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 }; | |||
enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 }; | |||
@@ -69,10 +69,9 @@ const std::string kDefaultReduction = "sum"; | |||
const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE"; | |||
const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE"; | |||
//deploy info | |||
// deploy info | |||
const char *const kAttrNeedReturnResult = "_need_return_result"; | |||
const char *const kAttrDeviceType = "_device_type"; | |||
// TODO:跟rts确认属性值 | |||
const char *const kDieDeviceTypeValue = "MultiMode"; | |||
const char *const kAttrDeviceId = "_device_id"; | |||
const char *const kAttrGraphName = "_graph_name"; | |||
@@ -80,7 +79,7 @@ const char *const kAttrGraphInputs = "_graph_inputs"; | |||
using GraphInputs = vector<GeTensorPtr>; | |||
using DeviceId = int64_t; | |||
using GraphInputNodes = vector<NodePtr>; | |||
} | |||
} // namespace | |||
class MdsUtils { | |||
public: | |||
// Parse the configuration file and determine whether to enable MDS based on the value of device_type. | |||
@@ -90,24 +89,25 @@ class MdsUtils { | |||
static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type); | |||
static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type); | |||
static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, | |||
int64_t fission_factor, | |||
static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, | |||
const std::string &group_name = ""); | |||
/// @param [in] index the axis along which to cut
/// @param [in] deploy_number the number of pieces to cut into
static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, | |||
CutType type, | |||
static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, | |||
int64_t deploy_number = kDeployNumber); | |||
// Sets the deploy information: the number of threads to start during the loading
// phase, the device each thread should run on, and the different input data constructed on each device.
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node); | |||
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, | |||
const std::multimap<DeviceId, GraphInputs> &deploys, | |||
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap<DeviceId, GraphInputs> &deploys, | |||
const std::string &device_type = kDieDeviceTypeValue); | |||
// Get cut policy for whole graph | |||
static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph); | |||
static GraphInputNodes GetInputNodes() { return input_nodes_; } | |||
static void AddInputNode(const NodePtr &input_node) { input_nodes_.push_back(input_node); } | |||
static GraphInputNodes GetInputNodes() { | |||
return input_nodes_; | |||
} | |||
static void AddInputNode(const NodePtr &input_node) { | |||
input_nodes_.push_back(input_node); | |||
} | |||
static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||
static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||
@@ -115,25 +115,16 @@ class MdsUtils { | |||
private: | |||
static GraphInputNodes input_nodes_; | |||
static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, | |||
const string &type, | |||
const string &node_name, | |||
size_t input_num, | |||
size_t output_num); | |||
static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, | |||
const string &name, | |||
const string &type, | |||
static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name, | |||
size_t input_num, size_t output_num); | |||
static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, | |||
const GeTensorDesc &tensor = GeTensorDesc()); | |||
static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, | |||
const OutDataAnchorPtr &src, | |||
const InDataAnchorPtr &dst, | |||
NodePtr &reduce_node); | |||
static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, | |||
const GeTensorDesc &tensor, | |||
Node *node, | |||
static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, | |||
const InDataAnchorPtr &dst, NodePtr &reduce_node); | |||
static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node, | |||
NodePtr &slice_node); | |||
static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node); | |||
static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph); | |||
}; | |||
} | |||
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ | |||
} // namespace ge | |||
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ |
@@ -43,30 +43,30 @@ Status ModelDeploySchedulerPass::CutProcess() { | |||
} | |||
auto type = MdsUtils::TryGetGraphCutType(compute_graph_); | |||
switch (type) { | |||
case kCutN:MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), | |||
"[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
case kCutN: | |||
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
break; | |||
case kCutH:MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), | |||
"[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
case kCutH: | |||
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
break; | |||
case kDynamicCutN:MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), | |||
"[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
case kDynamicCutN: | |||
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
break; | |||
case kDynamicCutH:MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), | |||
"[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
case kDynamicCutH: | |||
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
break; | |||
case kDynamicCutAll:MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), | |||
"[MDS][DynamicCutAll] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
case kDynamicCutAll: | |||
MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]", | |||
GetGraphName()); | |||
break; | |||
default:GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName()); | |||
default: | |||
GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName()); | |||
return SUCCESS; | |||
} | |||
} | |||
Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) { | |||
@@ -78,21 +78,19 @@ Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute | |||
op_kernel = DeploySchedulerKernel::Instance(); | |||
} | |||
if (is_dynamic) { | |||
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), | |||
"[MDS][DYNAMIC_CUTN] failed, node:[%s]", | |||
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]", | |||
node->GetName().c_str()); | |||
} else { | |||
MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str()); | |||
} | |||
bool is_grad_compute_node = false; | |||
if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) | |||
&& is_grad_compute_node) { | |||
if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && | |||
is_grad_compute_node) { | |||
grad_compute_nodes_.push_back(node); | |||
} | |||
} | |||
//TODO: Do breadth-wise fusion optimization for the allgather/allreduce nodes inserted for single-output multi-reference cases
MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), | |||
"[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]", | |||
// TODO: do breadth-wise fusion optimization for the allgather/allreduce nodes inserted for single-output, multi-reference cases
MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]", | |||
compute_graph->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -105,12 +103,10 @@ Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute | |||
op_kernel = DeploySchedulerKernel::Instance(); | |||
} | |||
if (is_dynamic) { | |||
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), | |||
"[MDS][DYNAMIC_CUTH] failed, node:[%s]", | |||
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]", | |||
node->GetName().c_str()); | |||
} else { | |||
MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -121,13 +117,11 @@ Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_gr | |||
std::vector<NodePtr> output_nodes; | |||
auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); | |||
auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); | |||
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), | |||
"[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||
compute_graph0->GetName().c_str()); | |||
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), | |||
"[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||
compute_graph1->GetName().c_str()); | |||
//TODO: Create a Case node, put the two graphs under its two branches, add the Case node to the original compute_graph, and construct the Case node's inputs
// TODO: Create a Case node, put the two graphs under the two branches of the Case node,
// add it to the original compute_graph, and construct its inputs
return SUCCESS; | |||
} | |||
@@ -143,9 +137,8 @@ Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) { | |||
Status ModelDeploySchedulerPass::SetDeployInfo() { | |||
vector<GeAttrValue::NAMED_ATTRS> deployInfo; | |||
REQUIRE (!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo), | |||
"%s already has deployed before!", | |||
GetGraphName()); | |||
REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo), | |||
"%s already has deployed before!", GetGraphName()); | |||
std::multimap<DeviceId, GraphInputs> deploys; | |||
for (int64_t j = 0; j < kDeployNumber; j++) { | |||
int64_t device_id = j; | |||
@@ -179,7 +172,6 @@ Status ModelDeploySchedulerPass::SMDPWeight() { | |||
return SUCCESS; | |||
} | |||
Status ModelDeploySchedulerPass::SMDPGradient() { | |||
//TODO: Mark the buffer pool id
return SUCCESS; | |||
} | |||
} | |||
} // namespace ge |
@@ -28,6 +28,7 @@ namespace ge { | |||
class ModelDeploySchedulerPass : public GraphPass { | |||
public: | |||
Status Run(ge::ComputeGraphPtr graph) override; | |||
private: | |||
// Part0:Process Func | |||
// cut and dynamic cut | |||
@@ -47,22 +48,24 @@ class ModelDeploySchedulerPass : public GraphPass { | |||
// set deploy info
Status SetDeployInfo(); | |||
//Part1: Utils Func | |||
// std::vector<bool> GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index); | |||
// std::vector<bool> GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index); | |||
// Part1: Utils Func | |||
// std::vector<bool> GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index); | |||
// std::vector<bool> GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index); | |||
Status HcomNodeFusionProcess(); | |||
Status GetAllModelStateVar(); | |||
Status GetAllWeightVar(); | |||
std::vector<NodePtr> GetAllGradComputeNodes() { return grad_compute_nodes_; } | |||
std::vector<NodePtr> GetAllGradComputeNodes() { | |||
return grad_compute_nodes_; | |||
} | |||
const char *GetGraphName() const { | |||
return compute_graph_->GetName().c_str(); | |||
} | |||
//members | |||
// members | |||
std::vector<NodePtr> model_state_vars_; | |||
std::vector<NodePtr> model_weight_vars_; | |||
std::vector<NodePtr> grad_compute_nodes_; | |||
ComputeGraphPtr compute_graph_ = nullptr; | |||
}; | |||
} // namespace ge | |||
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ | |||
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ |