diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index f98297d8..7b9950b3 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -168,6 +168,11 @@ set(EXECUTOR_SRC_LIST "graph/manager/util/debug.cc" #"graph/manager/util/hcom_util.cc" # Just for runner. "graph/passes/pass_utils.cc" + "graph/passes/mds_pass.cc" + "graph/passes/mds_kernels/mds_utils.cc" + "graph/passes/mds_kernels/variable_mds_kernel.cc" + "graph/passes/mds_kernels/conv2d_mds_kernel.cc" + "graph/passes/mds_kernels/base_mds_kernel.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" "host_kernels/broadcast_gradient_args_kernel.cc" diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc index d775309d..d93463af 100755 --- a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc +++ b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc @@ -76,5 +76,13 @@ Status GeLocalOpsKernelInfoStore::DestroySession(const map &sess // Do nothing return SUCCESS; } +Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) { + // TODO: + // 1. Whether the variable type is identified as a trainable variable + // 2, whether to turn on smdp1 and 3 + // To meet the above two points, set the current variable + // node to be tangent in the variable segmentation information + return SUCCESS; +} } // namespace ge_local } // namespace ge diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h index d35b01c7..51c68422 100755 --- a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h +++ b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h @@ -86,6 +86,8 @@ class GE_FUNC_VISIBILITY GeLocalOpsKernelInfoStore : public OpsKernelInfoStore { */ Status DestroySession(const std::map &session_options) override; + Status SetCutSupportedInfo(const ge::NodePtr &node) override; + // Copy prohibited GeLocalOpsKernelInfoStore(const GeLocalOpsKernelInfoStore &ops_kernel_store) = delete; diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc index 03abf91f..6a53c51c 100755 --- a/ge/graph/execute/graph_execute.cc +++ b/ge/graph/execute/graph_execute.cc @@ -22,8 +22,18 @@ #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/davinci_model.h" #include "common/profiling/profiling_manager.h" +#include "graph/debug/ge_attr_define.h" +#include "common/thread_pool.h" namespace ge { +namespace { +//deploy info +const char *const kAttrDeviceType = "_device_type"; +const char *const kAttrDeviceId = "_device_id"; +const char *const kAttrGraphName = "_graph_name"; +const char *const kAttrGraphInputs = "_graph_inputs"; +const char *const kAttrNeedReturnResult = "_need_return_result"; +} using Uint32Pair = pair; const uint32_t kInvalidModelId = UINT32_MAX; GraphExecutor::GraphExecutor() @@ -386,7 +396,14 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & } last_graph_id_ = graph_id; GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); - Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback); + vector deployInfo; + ModelIdInfo model_id_info; + Status ret; + if (ge::AttrUtils::GetListNamedAttrs(ge_root_model->GetRootGraph(), ATTR_NAME_DEPLOY_INFO, deployInfo)) { + ret = AsyncMultiExecuteModel(ge_root_model, input_tensor, callback); + } else { + ret = AsyncExecuteModel(ge_root_model, GetExecuteModelId(ge_root_model), input_tensor, callback); + } if (ret != SUCCESS) { GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[AsyncExecute][Model] Error! graph id:%u", graph_id); return GE_GRAPH_SYNC_MODEL_FAILED; @@ -522,10 +539,67 @@ Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_ro } return SUCCESS; } +Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &inputs, + const RunAsyncCallback &callback) { + // get deploy number of model instance + auto root_graph = ge_root_model->GetRootGraph(); + vector deploy_info; + if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { + GELOGE(FAILED, "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), + ATTR_NAME_DEPLOY_INFO.c_str()); + return FAILED; + } + auto thread_instances_size = deploy_info.size(); + auto model_ids = ge_root_model->GetAllModelId(); + if (model_ids.size() != thread_instances_size) { + GELOGE(FAILED, + "[AsyncMultiExecuteModel] something wrong, attr deploy numbers %zu should be equal to loaded models %zu", + thread_instances_size, model_ids.size()); + return FAILED; + } + + ThreadPool executor(thread_instances_size); + std::vector> vector_future; + for (size_t i = 0; i < thread_instances_size; ++i) { + auto thread_instance = deploy_info[i]; + std::vector graph_inputs; + if (ge::AttrUtils::MutableListTensor(thread_instance, kAttrGraphInputs, graph_inputs)) { + std::vector graph_input_updated(inputs.begin(), inputs.end()); + for (const auto &ge_tensor_ptr : graph_inputs) { + graph_input_updated.push_back(TensorAdapter::AsTensor(*ge_tensor_ptr)); + } + GraphExecutor graph_executor; + ExecuteModelFunc execute_model_func(&GraphExecutor::AsyncExecuteModel); + std::future f; + bool need_return_result = false; + if ((ge::AttrUtils::GetBool(thread_instance, kAttrNeedReturnResult, need_return_result) && need_return_result)) { + f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, + callback); + } else { + RunAsyncCallback callback_stub; + f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, + callback_stub); + } + if (!f.valid()) { + GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); + return FAILED; + } + vector_future.emplace_back(std::move(f)); + } + } + for (size_t i = 0; i < vector_future.size(); ++i) { + Status ret_status = vector_future[i].get(); + if (ret_status != SUCCESS) { + REPORT_CALL_ERROR("E19999", " Execute multi model %zu failed", i); + GELOGE(ret_status, "[AsyncMultiExecuteModel] Execute multi model failed", i); + return ret_status; + } + } + return SUCCESS; +} -Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &inputs, - const RunAsyncCallback &callback) { - uint32_t model_id = GetExecuteModelId(ge_root_model); +Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, + const std::vector &inputs, const RunAsyncCallback &callback) { if (model_id == kInvalidModelId) { GELOGE(INTERNAL_ERROR, "No valid model id."); return INTERNAL_ERROR; diff --git a/ge/graph/execute/graph_execute.h b/ge/graph/execute/graph_execute.h index 56e322f1..0c872202 100755 --- a/ge/graph/execute/graph_execute.h +++ b/ge/graph/execute/graph_execute.h @@ -136,8 +136,10 @@ class GraphExecutor { Status SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor); - Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &input_tensor, + Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, const std::vector &input_tensor, const RunAsyncCallback &callback); + Status AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &input_tensor, + const RunAsyncCallback &callback); void InitModelIdInfo(std::vector &out_model_id_info, std::vector &sub_graph_vec, uint32_t output_size); @@ -170,6 +172,11 @@ class GraphExecutor { std::vector buffer_addr_; std::vector buffer_size_; }; +using ExecuteModelFunc = std::function &inputs, + const RunAsyncCallback &callback)>; } // namespace ge #endif // GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_ diff --git a/ge/graph/execute/model_executor.cc b/ge/graph/execute/model_executor.cc index 993ba8c3..3f651f39 100644 --- a/ge/graph/execute/model_executor.cc +++ b/ge/graph/execute/model_executor.cc @@ -325,34 +325,45 @@ Status ModelExecutor::RunGraphWithStream(const GraphNodePtr &graph_node, GraphId Status ModelExecutor::ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { ge_root_model->SetIsSpecificStream(graph_node->IsSpecificStream()); - return ModelLoad(ge_root_model, graph_node, graph_run_listener_); + return ModelLoad(ge_root_model, graph_node, false); } Status ModelExecutor::ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { - auto listener = MakeShared(); - GE_CHECK_NOTNULL(listener); - return ModelLoad(ge_root_model, graph_node, listener); + return ModelLoad(ge_root_model, graph_node, true); } Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, - const std::shared_ptr &listener) { + bool is_async) { ge_root_model->SetTrainFlag(train_graph_flag_); bool is_unknown_shape = false; GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape)); + auto root_graph = ge_root_model->GetRootGraph(); if (!is_unknown_shape) { if (getenv(kEnvGeuseStaticMemory) != nullptr) { GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted."); } else { - auto root_graph = ge_root_model->GetRootGraph(); GE_CHECK_NOTNULL(root_graph); auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); GeModelPtr ge_model = name_to_model[root_graph->GetName()]; GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); } } + std::shared_ptr listener = + is_async ? std::dynamic_pointer_cast(MakeShared()) : std::dynamic_pointer_cast< + ModelListener>(graph_run_listener_); GE_TIMESTAMP_START(LoadModelOnline); uint32_t model_id = INVALID_MODEL_ID; - Status ret = GraphLoader::LoadModelOnline(model_id, ge_root_model, listener); + vector deployInfo; + Status ret; + if (ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deployInfo)) { + ret = GraphLoader::LoadMultiModelOnline(ge_root_model, is_async); + } else { + ret = GraphLoader::LoadModelOnline(model_id, + ge_root_model, + listener, + GetContext().DeviceId(), + kInvalidDieId); + } GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline"); if (ret != SUCCESS) { GELOGE(ret, "[Load][ModelOnline] Failed, model_id:%u", model_id); @@ -360,7 +371,6 @@ Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const Graph return ret; } graph_node->SetLoadFlag(true); - ge_root_model->SetModelId(model_id); graph_node->SetGeRootModel(ge_root_model); AddGraphNode(graph_node->GetGraphId(), graph_node); return SUCCESS; diff --git a/ge/graph/execute/model_executor.h b/ge/graph/execute/model_executor.h index f11441e9..32957dac 100644 --- a/ge/graph/execute/model_executor.h +++ b/ge/graph/execute/model_executor.h @@ -98,8 +98,7 @@ class ModelExecutor : public Executor { Status ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); Status ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); - Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, - const std::shared_ptr &listener); + Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, bool is_async); Status UnloadModel(const GeRootModelPtr &ge_root_model, uint32_t graph_id); diff --git a/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc index b2a61106..df60df4c 100755 --- a/ge/graph/load/graph_loader.cc +++ b/ge/graph/load/graph_loader.cc @@ -18,14 +18,24 @@ #include #include +#include #include "framework/common/helper/model_helper.h" #include "common/model_parser/model_parser.h" #include "graph/ge_context.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_var_manager.h" +#include "graph/debug/ge_attr_define.h" +#include "common/thread_pool.h" namespace ge { +namespace { +//deploy info +const char *const kAttrDeviceType = "_device_type"; +const char *const kAttrDeviceId = "_device_id"; +const char *const kAttrGraphName = "_graph_name"; +const char *const kAttrGraphInputs = "_graph_inputs"; +} Status GraphLoader::UnloadModel(uint32_t model_id) { auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); @@ -45,43 +55,81 @@ Status GraphLoader::UnloadModel(uint32_t model_id) { return SUCCESS; } -Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_root_model_ptr, - const std::shared_ptr &listener) { - GELOGI("Load model online begin."); - rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); - GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); +Status GraphLoader::SetDevice(uint32_t device_id, int64_t die_id) { + if (device_id != kInvalidDeviceId && die_id != kInvalidDieId) { + rtError_t rt_ret = rtSetDevice(device_id, kMultiMode); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret); + GELOGE(RT_FAILED, "[Call][rtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret); + return RT_FAILED; + } + rt_ret = rtSetDieId(die_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDieId failed, device_id:%u, ret:0x%X", die_id, rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] rtSetDieId, device_id:%u, ret:0x%X", die_id, rt_ret); + return RT_FAILED; + } + } else if (device_id != kInvalidDeviceId && die_id == kInvalidDieId) { + rtError_t rt_ret = rtSetDevice(device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret); + return RT_FAILED; + } + } else { + REPORT_CALL_ERROR("E19999", "Call SetDevice failed, device_id:%u, die_id:%ld", device_id, die_id); + GELOGE(RT_FAILED, "[Call][SetDevice] failed, device_id:%u, die_id:%ld", device_id, die_id); return RT_FAILED; } + return SUCCESS; +} + +Status GraphLoader::ResetDevice(uint32_t device_id, int64_t die_id) { + if (die_id != kInvalidDieId) { + rtError_t rt_ret = rtDieReset(die_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", die_id, rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", die_id, rt_ret); + return RT_FAILED; + } + } else { + rtError_t rt_ret = rtDeviceReset(device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret); + return RT_FAILED; + } + } + return SUCCESS; +} + +Status GraphLoader::LoadModelOnline(uint32_t &model_id, + const std::shared_ptr &ge_root_model_ptr, + const std::shared_ptr &listener, + uint32_t device_id, + int64_t die_id) { + GELOGI("Load model online begin."); if (ge_root_model_ptr == nullptr) { REPORT_INNER_ERROR("E19999", "Check param ge_root_model_ptr nullptr, check invalid"); GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph][Check][Param] GE load graph model_ptr is nullptr."); return GE_GRAPH_PARAM_NULLPTR; } - + if (SetDevice(device_id, die_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Call SetDevice failed, device_id:%u", device_id); + GELOGE(RT_FAILED, "[Call][SetDevice] failed, device_id:%u", device_id); + return RT_FAILED; + } + GE_MAKE_GUARD(reset_device, [&] { GE_CHK_RT(ResetDevice(device_id, die_id)); }); auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener); + Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener,device_id, die_id); if (ret != SUCCESS) { GELOGE(ret, "[Load][Model] Online failed. ret = %u, model_id:%u", ret, model_id); - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", - GetContext().DeviceId(), rt_ret); - GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); - } return ret; } - + ge_root_model_ptr->SetModelId(model_id); if (ge_root_model_ptr->IsSpecificStream()) { GELOGI("No need to start a new thread to run model in specific scene."); - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", - GetContext().DeviceId(), rt_ret); - GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); - } return SUCCESS; } ret = model_manager->Start(model_id); @@ -89,25 +137,77 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptrUnload(model_id) != SUCCESS) { GELOGE(ret, "[Unload][Model] failed while trying to unload after a failed start, model_id:%u.", model_id); } - - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", - GetContext().DeviceId(), rt_ret); - GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); - } - GELOGE(ret, "[Start][Model] failed, model_id:%u.", model_id); return ret; } - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", - GetContext().DeviceId(), rt_ret); - GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); - return RT_FAILED; - } GELOGI("Load model online success, model_id:%u.", model_id); + return SUCCESS; +} + +Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr &ge_root_model, bool is_async) { + // get deploy number of model instance + auto root_graph = ge_root_model->GetRootGraph(); + vector deploy_info; + if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { + GELOGE(FAILED, "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s", + root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str()); + return FAILED; + } + auto thread_instances_size = deploy_info.size(); + auto device_id_fission_from = GetContext().DeviceId(); + GELOGI("Graph %s need to load model %zu times, and fission from device %u.", root_graph->GetName().c_str(), + thread_instances_size, device_id_fission_from); + ThreadPool executor(thread_instances_size); + std::vector> vector_future; + GE_TIMESTAMP_START(LoadModelOnline); + for (size_t i = 0; i < thread_instances_size; ++i) { + auto thread_instance = deploy_info[i]; + std::string device_type; + ModelIdInfo model_id_info; + std::shared_ptr listener; + if (is_async) { + listener = MakeShared(); + GE_CHECK_NOTNULL(listener); + } else { + // TODO: GraphModelListener for sync + } + int64_t device_id_fissioned = kInvalidDieId; + if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) || + device_id_fissioned == kInvalidDieId) { + REPORT_CALL_ERROR("E19999", "graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), + ATTR_NAME_DEPLOY_INFO.c_str()); + GELOGE(GRAPH_FAILED, "[LoadMultiModelOnline] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), + ATTR_NAME_DEPLOY_INFO.c_str()); + return GRAPH_FAILED; + }; + if (ge::AttrUtils::GetStr(thread_instance, kAttrDeviceType, device_type) && device_type == kMultiMode) { + std::future f = executor.commit(GraphLoader::LoadModelOnline, model_id_info.model_id, ge_root_model, + listener, device_id_fission_from, device_id_fissioned); + if (!f.valid()) { + GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); + return FAILED; + } + vector_future.emplace_back(std::move(f)); + } else { + std::future f = executor.commit(GraphLoader::LoadModelOnline, model_id_info.model_id, ge_root_model, + listener, device_id_fissioned, kInvalidDieId); + if (!f.valid()) { + GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); + return FAILED; + } + vector_future.emplace_back(std::move(f)); + } + } + GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline"); + + for (size_t i = 0; i < vector_future.size(); ++i) { + Status ret_status = vector_future[i].get(); + if (ret_status != SUCCESS) { + REPORT_CALL_ERROR("E19999", " Load multi model %zu failed", i); + GELOGE(ret_status, "[LoadMultiModelOnline] Load multi model failed", i); + return ret_status; + } + } return SUCCESS; } diff --git a/ge/graph/load/graph_loader.h b/ge/graph/load/graph_loader.h index f6324c98..27a5bf2e 100755 --- a/ge/graph/load/graph_loader.h +++ b/ge/graph/load/graph_loader.h @@ -30,6 +30,13 @@ #include "runtime/mem.h" namespace ge { +namespace { +const int64_t kInvalidDieId = -1; +const uint32_t kInvalidDeviceId = UINT32_MAX; +const char* kMultiMode ="MultiMode"; +const char* kSingleMode ="SingleMode"; +} + class GraphLoader { public: GraphLoader() = default; @@ -64,9 +71,12 @@ class GraphLoader { static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id); static Status DestroyAicpuSessionForInfer(uint32_t model_id); - static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_root_model, - const std::shared_ptr &listener); + const std::shared_ptr &listener, uint32_t device_id, + int64_t die_id = kInvalidDieId); + static Status SetDevice(uint32_t device_id, int64_t die_id); + static Status ResetDevice(uint32_t device_id, int64_t die_id); + static Status LoadMultiModelOnline(const std::shared_ptr &ge_root_model_ptr, bool is_async); }; } // namespace ge #endif // GE_GRAPH_LOAD_GRAPH_LOADER_H_ diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 495ec28e..6c8e6c62 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -444,16 +444,16 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { Status DavinciModel::InitVariableMem() { // malloc variable memory base - var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); + var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM, GetDeviceId()); if (TotalVarMemSize() && (var_mem_base_ == nullptr)) { - Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize()); + Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize(), GetDeviceId()); if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "MallocVarMemory fail, var_size:%zu, model_id:%u, check invalid", TotalVarMemSize(), model_id_); GELOGE(ret, "[Malloc][VarMemory] failed, var_size:%zu, model_id:%u", TotalVarMemSize(), model_id_); return ret; } - var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); + var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM, GetDeviceId()); GEEVENT("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, var_mem_base_, TotalVarMemSize()); } @@ -2819,18 +2819,16 @@ void *DavinciModel::Run(DavinciModel *model) { bool seq_end_flag = false; uint32_t model_id = model->Id(); uint32_t device_id = model->GetDeviceId(); + int64_t die_id = model->GetDieId(); ErrorManager::GetInstance().SetErrorContext(model->GetErrorContext()); GELOGI("Model Run thread start, model_id:%u.", model_id); - rtError_t rt_ret = rtSetDevice(static_cast(device_id)); - if (rt_ret != RT_ERROR_NONE) { - - GELOGE(FAILED, "[Run][Rtsetdevice] failed, model_id:%u, device_id:%u.", model_id, device_id); + if (GraphLoader::SetDevice(device_id, die_id) != SUCCESS) { + GELOGE(FAILED, "[Run][Setdevice] failed, model_id:%u, device_id:%u die_id%ld.", model_id, device_id, die_id); return nullptr; } // DeviceReset before thread run finished! - GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); - + GE_MAKE_GUARD(reset_device, [&] { GE_CHK_RT(GraphLoader::ResetDevice(device_id, model->GetDieId())); }); ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); while (model->RunFlag()) { // Model hasn't truly started runing before received data @@ -2886,7 +2884,7 @@ void *DavinciModel::Run(DavinciModel *model) { GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); GE_TIMESTAMP_START(rtModelExecute); GELOGI("rtModelExecute start."); - rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); + auto rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); continue); diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 76b0beef..08d3eea0 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -59,7 +59,9 @@ namespace ge { // op debug need 2048 bits buffer const size_t kOpDebugMemorySize = 2048UL; const size_t kDebugP2pSize = 8UL; +const size_t kDebugP2pSize = 8UL; +const int64_t kInvalidDieId = -1; typedef enum tagModelProcStage { MODEL_LOAD_START = 1, MODEL_LOAD_END, @@ -441,13 +443,17 @@ class DavinciModel { /// @return void /// void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + void SetDieId(int64_t die_id) { die_id_ = die_id; } /// /// @ingroup ge /// @brief Get device Id /// @return device id /// - uint32_t GetDeviceId() const { return device_id_; } + uint32_t GetDeviceId() const { + return die_id_ == kInvalidDieId ? device_id_ : die_id_; + } + int64_t GetDieId() const { return die_id_; } bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; } @@ -1010,6 +1016,7 @@ class DavinciModel { struct error_message::Context error_context_; uint32_t device_id_; + int64_t die_id_ = kInvalidDieId; mutex flowctrl_op_index_internal_map_mutex_; map flowctrl_op_index_internal_map_; diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index d0d88e66..25499924 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -324,7 +324,7 @@ bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) { /// @return Status run result /// Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr &ge_root_model, - std::shared_ptr listener) { + std::shared_ptr listener, uint32_t &device_id, int64_t die_id) { GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "[Check][Param] Param incorrect, listener is null"); if (model_id == INVALID_MODEL_ID) { GenModelId(&model_id); @@ -342,7 +342,8 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetId(model_id); - davinci_model->SetDeviceId(GetContext().DeviceId()); + davinci_model->SetDeviceId(device_id); + davinci_model->SetDieId(die_id); auto root_graph = ge_root_model->GetRootGraph(); GE_CHECK_NOTNULL(root_graph); diff --git a/ge/graph/load/model_manager/model_manager.h b/ge/graph/load/model_manager/model_manager.h index 6389d6db..9448ce02 100755 --- a/ge/graph/load/model_manager/model_manager.h +++ b/ge/graph/load/model_manager/model_manager.h @@ -71,7 +71,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @author @ /// ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_root_model, - std::shared_ptr listener); + std::shared_ptr listener,uint32_t &device_id, int64_t die_id); ge::Status DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, const shared_ptr &ge_root_model, diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index fa140bfe..4d03d5a0 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -98,6 +98,7 @@ #include "graph/passes/hccl_continuous_memcpy_pass.h" #include "graph/passes/parallel_group_pass.h" #include "graph/passes/buffer_pool_memory_pass.h" +#include "graph/passes/mds_pass.h" #include "graph/build/label_allocator.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" @@ -110,6 +111,7 @@ #include "external/graph/types.h" #include "common/util/error_manager/error_manager.h" #include "common/profiling/profiling_manager.h" +#include "graph/debug/ge_attr_define.h" namespace { const char *const kSummary = "Summary"; @@ -1087,7 +1089,6 @@ Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphN if (!options_.run_graph_flag) { return SUCCESS; } - ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); GE_CHECK_NOTNULL(executor_); return executor_->LoadGraph(ge_root_model, graph_node); @@ -2816,9 +2817,40 @@ const map *GraphManager::GetGraphOptions(uint32_t grap } void GraphManager::SetOptionsRunGraphFlag(bool run_graph_flag) { options_.run_graph_flag = run_graph_flag; } +Status GraphManager::SetNodeCutInfo(ComputeGraphPtr &compute_graph) { + auto instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + REPORT_INNER_ERROR("E19999", "GeLib is not init before, check invalid"); + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized"); + return FAILED; + } + for (const auto &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + auto kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName(); + OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); + if (kernel_info == nullptr) { + REPORT_INNER_ERROR("E19999", "Find ops kernel by name:%s failed", + kernel_lib_name.c_str()); + GELOGE(FAILED, "[Get][OpsKernelInfoStore] by name:%s failed", kernel_lib_name.c_str()); + return FAILED; + } + GE_CHK_STATUS_RET(kernel_info->SetCutSupportedInfo(node)); + } +} Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id) { + GE_TIMESTAMP_START(MDS); + // Set the cut support information based on the engine of the node + EnginePlacer engine_placer; + engine_placer.SetComputeGraph(compute_graph); + GE_CHK_STATUS_RET(engine_placer.Run()); + GE_CHK_STATUS_RET(SetNodeCutInfo(compute_graph)); + // mds pass + PassManager graph_pass; + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeSubgraph::MDS", new (std::nothrow) ModelDeploySchedulerPass)) + GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); + GE_TIMESTAMP_EVENT_END(MDS, "OptimizeSubgraph::MDS"); // graph partition // Stage partition, only for root graph GE_TIMESTAMP_START(StagePartition); diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index e7cd88a9..d304d4ae 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -242,7 +242,7 @@ class GraphManager { uint64_t session_id = INVALID_SESSION_ID); Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); - + Status SetNodeCutInfo (ComputeGraphPtr &compute_graph); Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model, uint64_t session_id); diff --git a/ge/graph/manager/graph_mem_allocator.cc b/ge/graph/manager/graph_mem_allocator.cc index dd38274e..b0c5399e 100755 --- a/ge/graph/manager/graph_mem_allocator.cc +++ b/ge/graph/manager/graph_mem_allocator.cc @@ -23,12 +23,16 @@ Status MemoryAllocator::Initialize(uint32_t device_id) { GELOGI("MemoryAllocator::Initialize"); // when redo Initialize free memory - for (auto &it : memory_base_map_) { - if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { - GELOGW("Initialize: FreeMemory failed"); + for (auto &it_map : deviceid_2_memory_bases_map_) { + + for (auto &it : it_map.second) { + if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { + GELOGW("Initialize: FreeMemory failed"); + } } + it_map.second.clear(); } - memory_base_map_.clear(); + deviceid_2_memory_bases_map_.clear(); return SUCCESS; } @@ -36,12 +40,16 @@ void MemoryAllocator::Finalize(uint32_t device_id) { GELOGI("MemoryAllocator::Finalize"); // free memory - for (auto &it : memory_base_map_) { - if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { - GELOGW("Finalize: FreeMemory failed"); + for (auto &it_map : deviceid_2_memory_bases_map_) { + + for (auto &it : it_map.second) { + if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { + GELOGW("Finalize: FreeMemory failed"); + } } + it_map.second.clear(); } - memory_base_map_.clear(); + deviceid_2_memory_bases_map_.clear(); } uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id) const { @@ -75,12 +83,16 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, uint32_t device_id) { - auto it = memory_base_map_.find(memory_key); - if (it != memory_base_map_.end()) { - it->second.memory_used_num_++; - return it->second.memory_addr_; + map memory_base_map; + auto it_map = deviceid_2_memory_bases_map_.find(device_id); + if (it_map != deviceid_2_memory_bases_map_.end()) { + memory_base_map = it_map->second; + auto it = it_map->second.find(memory_key); + if (it != it_map->second.end()) { + it->second.memory_used_num_++; + return it->second.memory_addr_; + } } - uint8_t *memory_addr = MallocMemory(purpose, memory_size, device_id); if (memory_addr == nullptr) { @@ -91,16 +103,27 @@ uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memo return nullptr; } - MemoryInfo memory_info(memory_addr, memory_size); + MemoryInfo memory_info(memory_addr, memory_size, device_id); memory_info.memory_used_num_++; - memory_base_map_[memory_key] = memory_info; + memory_base_map[memory_key] = memory_info; + deviceid_2_memory_bases_map_[device_id] = memory_base_map; mem_malloced_ = true; return memory_addr; } Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) { - auto it = memory_base_map_.find(memory_key); - if (it == memory_base_map_.end()) { + auto it_map = deviceid_2_memory_bases_map_.find(device_id); + if (it_map == deviceid_2_memory_bases_map_.end()){ + if (mem_malloced_) { + GELOGW( + "MemoryAllocator::FreeMemory failed," + " memory_key[%s] was not exist, device_id = %u.", + memory_key.c_str(), device_id); + } + return ge::INTERNAL_ERROR; + } + auto it = it_map->second.find(memory_key); + if (it == it_map->second.end()) { if (mem_malloced_) { GELOGW( "MemoryAllocator::FreeMemory failed," @@ -109,7 +132,6 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) } return ge::INTERNAL_ERROR; } - if (it->second.memory_used_num_ > 1) { GELOGW("MemoryAllocator::FreeMemory memory_key[%s] should not be released, reference count %d", memory_key.c_str(), it->second.memory_used_num_); @@ -129,20 +151,28 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) GELOGI("MemoryAllocator::FreeMemory device_id = %u", device_id); - memory_base_map_.erase(it); + it_map->second.erase(it); return ge::SUCCESS; } uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t device_id) { - auto it = memory_base_map_.find(memory_key); - if (it == memory_base_map_.end()) { + auto it_map = deviceid_2_memory_bases_map_.find(device_id); + if (it_map == deviceid_2_memory_bases_map_.end()) { + + GELOGW( + "MemoryAllocator::GetMemoryAddr failed," + " memory_key[%s] was not exist, device_id = %u.", + memory_key.c_str(), device_id); + return nullptr; + } + auto it = it_map->second.find(memory_key); + if (it == it_map->second.end()) { GELOGW( "MemoryAllocator::GetMemoryAddr failed," " memory_key[%s] was not exist, device_id = %u.", memory_key.c_str(), device_id); return nullptr; } - return it->second.memory_addr_; } } // namespace ge diff --git a/ge/graph/manager/graph_mem_allocator.h b/ge/graph/manager/graph_mem_allocator.h index b6d73f0a..06a65bfb 100644 --- a/ge/graph/manager/graph_mem_allocator.h +++ b/ge/graph/manager/graph_mem_allocator.h @@ -32,10 +32,13 @@ namespace ge { class MemoryInfo { public: - MemoryInfo() : memory_addr_(nullptr), memory_size_(0), memory_used_num_(0) {} + MemoryInfo() : memory_addr_(nullptr), memory_size_(0), memory_used_num_(0), device_id_(0) {} MemoryInfo(uint8_t *memory_addr, size_t memory_size) - : memory_addr_(memory_addr), memory_size_(memory_size), memory_used_num_(0) {} + : memory_addr_(memory_addr), memory_size_(memory_size), memory_used_num_(0), device_id_(0) {} + + MemoryInfo(uint8_t *memory_addr, size_t memory_size, uint32_t device_id) + : memory_addr_(memory_addr), memory_size_(memory_size), device_id_(device_id), memory_used_num_(0) {} MemoryInfo &operator=(const MemoryInfo &op) { if (&op == this) { @@ -44,7 +47,7 @@ class MemoryInfo { this->memory_addr_ = op.memory_addr_; this->memory_size_ = op.memory_size_; - this->memory_used_num_ = op.memory_used_num_; + this->device_id_ = op.device_id_; return *this; } @@ -52,12 +55,14 @@ class MemoryInfo { this->memory_addr_ = op.memory_addr_; this->memory_size_ = op.memory_size_; this->memory_used_num_ = op.memory_used_num_; + this->device_id_ = op.device_id_; } virtual ~MemoryInfo() = default; uint8_t *memory_addr_; uint64_t memory_size_; int32_t memory_used_num_; + uint32_t device_id_; }; class MemoryAllocator { @@ -133,7 +138,7 @@ class MemoryAllocator { private: rtMemType_t memory_type_; bool mem_malloced_; - map memory_base_map_; + map> deviceid_2_memory_bases_map_; }; } // namespace ge diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index ce5b335e..958f08b1 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -348,7 +348,7 @@ ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, device_id_ = device_id; session_id_ = session_id; job_id_ = job_id; - var_resource_ = std::unique_ptr(new (std::nothrow) VarResource(session_id_)); + var_resource_ = std::unique_ptr(new(std::nothrow) VarResource(session_id_)); if (var_resource_ == nullptr) { GELOGW("VarManager init failed session id = %lu.", session_id); return ge::INTERNAL_ERROR; @@ -637,7 +637,7 @@ rtMemType_t VarManager::GetVarMemType(const int64_t &offset) { return var_resource_->GetVarMemType(offset); } -ge::Status VarManager::MallocVarMemory(size_t memory_size) { +ge::Status VarManager::MallocVarMemory(size_t memory_size, uint32_t device_id) { std::lock_guard lock(mutex_); uint8_t *var_mem_base = nullptr; string memory_key = std::to_string(session_id_); @@ -649,7 +649,7 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) { var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; const string purpose("variables and constant op memory in training network."); - var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size); + var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size, device_id); if (var_mem_base == nullptr) { GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s", var_memory_size, memory_key.c_str()); @@ -658,22 +658,22 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) { return SUCCESS; } -uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) { +uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type, uint32_t device_id) { std::lock_guard lock(mutex_); if (memory_type == RT_MEMORY_RDMA_HBM) { return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr(); } string memory_key = std::to_string(session_id_); - return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key); + return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key, device_id); } -uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) { +uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type, uint32_t device_id) { std::lock_guard lock(mutex_); if (memory_type == RT_MEMORY_RDMA_HBM) { return logic_addr; } string mem_key = std::to_string(session_id_); - uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key); + uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key, device_id); if (mem_base == nullptr) { return nullptr; } diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h index f0e3b89b..1841c336 100755 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -231,7 +231,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); - ge::Status MallocVarMemory(size_t memory_size = kMemoryVarManagerMallocSize); + ge::Status MallocVarMemory(size_t memory_size = kMemoryVarManagerMallocSize, uint32_t device_id = 0); ge::Status FreeVarMemory(); @@ -277,9 +277,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { rtMemType_t GetVarMemType(const int64_t &offset); - uint8_t *GetVarMemoryBase(rtMemType_t memory_type); + uint8_t *GetVarMemoryBase(rtMemType_t memory_type, uint32_t device_id = 0); - uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type); + uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type, uint32_t device_id = 0); Status GetAllVariables(std::map &all_variables); @@ -293,6 +293,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { size_t var_mem_logic_base_; size_t use_max_mem_size_; std::unique_ptr var_resource_; +// map> var_resource_map_; map mem_resource_map_; mutable std::recursive_mutex mutex_; diff --git a/ge/graph/passes/mds_kernels/base_mds_kernel.cc b/ge/graph/passes/mds_kernels/base_mds_kernel.cc new file mode 100644 index 00000000..53f47ac4 --- /dev/null +++ b/ge/graph/passes/mds_kernels/base_mds_kernel.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "./base_mds_kernel.h" +namespace ge { +namespace mds_cut_pass { +shared_ptr GetKernelByType(const NodePtr &node) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return nullptr; + } + KernelFactory &factory = KernelFactory::Instance(); + string type = node->GetType(); + if (type == FRAMEWORKOP) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { + REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + return nullptr; + } + } + + return factory.Create(type); +} +} // namespace mds_cut_pass +shared_ptr DeploySchedulerKernel::Instance() { + static const std::shared_ptr instance_ptr = + shared_ptr(new (std::nothrow) DeploySchedulerKernel()); + return instance_ptr; +} + +Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (auto &in_anchor : node->GetAllInDataAnchors()) { + GE_CHECK_NOTNULL(in_anchor); + auto src_anchor = in_anchor->GetPeerOutAnchor(); + if (src_anchor == nullptr) { + continue; + } + auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx()); + auto src_node = src_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_op_desc = src_node->GetOpDesc(); + auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx()); + GE_CHECK_NOTNULL(src_tensor_desc); + // peer out shape is cutted already + if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutN)) { + if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { + tensor_desc->SetShape(src_tensor_desc->GetShape()); + } else { + MDS_REQUIRE_SUCCESS( + MdsUtils::DataGather(src_anchor, in_anchor), "[CutN] failed to gather between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); + } + } else { + if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { + MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), + "[CutN] failed to slice between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), + in_anchor->GetIdx()); + } else { + tensor_desc->SetShape(src_tensor_desc->GetShape()); + } + } + + // insert hcomallreduce for cutn + bool is_grad_compute_node = false; + if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && + is_grad_compute_node) { + MDS_REQUIRE_SUCCESS( + MdsUtils::DataReduce(src_anchor, in_anchor), "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); + } + } + + // call infer shape, update output shape + MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutN] %s call infershape failed", node->GetName().c_str()); + return SUCCESS; +} + +Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (auto &in_anchor : node->GetAllInDataAnchors()) { + GE_CHECK_NOTNULL(in_anchor); + auto src_anchor = in_anchor->GetPeerOutAnchor(); + if (src_anchor == nullptr) { + continue; + } + auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx()); + auto src_node = src_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_op_desc = src_node->GetOpDesc(); + auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx()); + GE_CHECK_NOTNULL(src_tensor_desc); + // peer out shape is cutted already + if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutH)) { + if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { + MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()), + "[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), + in_anchor->GetIdx()); + } else { + MDS_REQUIRE_SUCCESS( + MdsUtils::DataGather(src_anchor, in_anchor), "[CutH] failed to gather between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); + } + } else { + if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { + MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), + "[CutH] failed to slice between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), + in_anchor->GetIdx()); + } else { + MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true), + "[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", + src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), + in_anchor->GetIdx()); + } + } + } + + // call infer shape, update output shape + MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] call infer shape failed", node->GetName().c_str()); + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/mds_kernels/base_mds_kernel.h b/ge/graph/passes/mds_kernels/base_mds_kernel.h new file mode 100644 index 00000000..8f46a696 --- /dev/null +++ b/ge/graph/passes/mds_kernels/base_mds_kernel.h @@ -0,0 +1,76 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ +#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ + +#include + +#include "common/op/ge_op_utils.h" +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/debug/ge_op_types.h" +#include "framework/common/types.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/shape_refiner.h" +#include "../pass_utils.h" +#include "./mds_utils.h" +#include "./mds_kernel_factory.h" + +namespace ge { +class DeploySchedulerKernel { + public: + static shared_ptr Instance(); + + /// CutN imply + /// @param [in] node_ptr + virtual Status CutN(const ge::NodePtr &node_ptr); + + /// CutH imply + /// @param [in] node_ptr + virtual Status CutH(const ge::NodePtr &node_ptr); + + /// DynamicCutN imply + /// @param [in] node_ptr + virtual Status DynamicCutN(const ge::NodePtr &node_ptr); + + /// DynamicCutH imply + /// @param [in] node_ptr + virtual Status DynamicCutH(const ge::NodePtr &node_ptr); + + // halo exchange process + Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false); + + NodePtr GetInputNode() { + return input_node_; + } + DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete; + DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete; + + protected: + DeploySchedulerKernel() = default; + virtual ~DeploySchedulerKernel() = default; + + private: + NodePtr input_node_ = nullptr; +}; + +namespace mds_cut_pass { +shared_ptr GetKernelByType(const NodePtr &node); +} +} // namespace ge +#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ \ No newline at end of file diff --git a/ge/graph/passes/mds_kernels/conv2d_mds_kernel.cc b/ge/graph/passes/mds_kernels/conv2d_mds_kernel.cc new file mode 100644 index 00000000..29b1fd66 --- /dev/null +++ b/ge/graph/passes/mds_kernels/conv2d_mds_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "conv2d_mds_kernel.h" +#include "mds_kernel_factory.h" + +namespace ge { +Status Conv2dDeploySchedulerKernel::CutN(const ge::NodePtr node_ptr) { + return DeploySchedulerKernel::CutN(node_ptr); +} +Status Conv2dDeploySchedulerKernel::CutH(const ge::NodePtr node_ptr) { + return DeploySchedulerKernel::CutH(node_ptr); +} + +REGISTER_MDS_KERNEL(CONV2D, Conv2dDeploySchedulerKernel); +} // namespace ge + diff --git a/ge/graph/passes/mds_kernels/conv2d_mds_kernel.h b/ge/graph/passes/mds_kernels/conv2d_mds_kernel.h new file mode 100644 index 00000000..e20c5b5f --- /dev/null +++ b/ge/graph/passes/mds_kernels/conv2d_mds_kernel.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_ +#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_ + +#include "base_mds_kernel.h" + +namespace ge { +class Conv2dDeploySchedulerKernel : public DeploySchedulerKernel { + public: + Status CutN(const ge::NodePtr& node_ptr) override; + Status CutH(const ge::NodePtr& node_ptr) override; +}; +} // namespace ge +#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_ diff --git a/ge/graph/passes/mds_kernels/mds_kernel_factory.h b/ge/graph/passes/mds_kernels/mds_kernel_factory.h new file mode 100644 index 00000000..1f3375c2 --- /dev/null +++ b/ge/graph/passes/mds_kernels/mds_kernel_factory.h @@ -0,0 +1,102 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_ +#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_ + +#include +#include +#include +#include + +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/graph.h" + +using std::string; + +namespace ge { +class DeploySchedulerKernel; + +/// +/// @brief DeploySchedulerKernel create factory +/// +class KernelFactory { + public: + // KernelCreator(function), type definition + using KERNEL_CREATOR_FUN = std::function(void)>; + + /// + /// Get singleton instance + /// + static KernelFactory &Instance() { + static KernelFactory instance; + return instance; + } + + /// + /// create DeploySchedulerKernel + /// @param [in] op_type operation type + /// + std::shared_ptr Create(const std::string &op_type) { + std::map::iterator iter = creator_map_.find(op_type); + if (iter != creator_map_.end()) { + return iter->second(); + } + return nullptr; + } + + // DeploySchedulerKernel registration function to register different types of DeploySchedulerKernel to the factory + class Register { + public: + /// + /// @brief Constructor + /// @param [in] type operation type + /// @param [in| fun DeploySchedulerKernel function of the operation + /// + Register(const string &type, const KERNEL_CREATOR_FUN &fun) { + KernelFactory::Instance().RegisterCreator(type, fun); + } + ~Register() = default; + }; + + protected: + KernelFactory() = default; + ~KernelFactory() = default; + + // register creator, this function will call in the constructor + void RegisterCreator(const string &type, const KERNEL_CREATOR_FUN &fun) { + std::map::iterator iter = creator_map_.find(type); + if (iter != creator_map_.end()) { + GELOGW("KernelFactory::RegisterCreator: %s creator already exist", type.c_str()); + return; + } + + creator_map_[type] = fun; + } + + private: + std::map creator_map_{}; +}; + +#define REGISTER_MDS_KERNEL(type, clazz) \ + std::shared_ptr Creator_##type##_Kernel() { \ + std::shared_ptr ptr = nullptr; \ + ptr = MakeShared(); \ + return ptr; \ + } \ + KernelFactory::Register g_##type##_Kernel_Creator(type, Creator_##type##_Kernel) +} // namespace ge +#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_ diff --git a/ge/graph/passes/mds_kernels/mds_utils.cc b/ge/graph/passes/mds_kernels/mds_utils.cc new file mode 100644 index 00000000..b942b8ab --- /dev/null +++ b/ge/graph/passes/mds_kernels/mds_utils.cc @@ -0,0 +1,476 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "./mds_utils.h" +namespace ge { +namespace { +// for count +thread_local int64_t data_slice_count = 0; +thread_local int64_t data_gather_count = 0; +thread_local int64_t data_reduce_count = 0; +const std::string kPrefix = "mds"; +} // namespace +int64_t MdsUtils::GetNLocation(Format fmt) { + int64_t loc = kNInvalidLocation; + switch (fmt) { + case FORMAT_NCHW: + case FORMAT_NHWC: + loc = kNLocation0; + break; + case FORMAT_CHWN: + case FORMAT_HWCN: + loc = kNLocation3; + break; + default: + GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); + } + return loc; +} +int64_t MdsUtils::GetHLocation(Format fmt) { + int64_t loc = kHInvalidLocation; + switch (fmt) { + case FORMAT_HWCN: + loc = kHLocation0; + break; + case FORMAT_NHWC: + case FORMAT_CHWN: + loc = kHLocation1; + break; + case FORMAT_NCHW: + loc = kHLocation2; + default: + GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); + } + return loc; +} +int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) { + Format fmt = ge_tensor_desc->GetFormat(); + switch (type) { + case kCutN: + return GetNLocation(fmt); + case kCutH: + return GetHLocation(fmt); + default:; + } + GELOGE(FAILED, "[MDS]invalid CutType:%d", type); + return kInvalidIndex; +} +bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type) { + if (ge_tensor_desc == nullptr) { + REPORT_INNER_ERROR("E19999", "invalid input param: tensor is null!"); + GELOGE(FAILED, "[MDS]invalid input param: tensor is null!"); + return false; + } + if (type != kCutN && type != kCutH) { + REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type); + GELOGE(FAILED, "[MDS]invalid CutType:%d", type); + return false; + } + int64_t cut_index = GetIndexByFormat(ge_tensor_desc, type); + if (cut_index == kInvalidIndex) { + REPORT_INNER_ERROR("E19999", "invalid index param:%ld", cut_index); + GELOGE(FAILED, "[MDS]", "invalid index param:%ld", cut_index); + return false; + } + auto dims = ge_tensor_desc->GetShape().GetDims(); + if (cut_index < 0 || cut_index >= dims.size()) { + REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, + dims.size()); + GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, + dims.size()); + return false; + } + if (dims[cut_index] % kDeployNumber != 0) { + GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]); + return false; + } + vector cut_support_info; + if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) { + REPORT_INNER_ERROR("E19999", "call GetlistInt failed"); + GELOGE(FAILED, "[MDS]", "call GetlistInt failed"); + return false; + } + if (cut_index < 0 || cut_index >= cut_support_info.size()) { + REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, + type, cut_support_info.size()); + GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, + type, cut_support_info.size()); + return false; + } + if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) { + REPORT_INNER_ERROR("E19999", "invalid cut info value:%ld", cut_support_info[cut_index]); + GELOGE(FAILED, "[MDS]", "invalid cut info value:%ld", cut_support_info[cut_index]); + return false; + } + return cut_support_info[cut_index] & kSplitCutSupported; +} +Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) { + GE_CHECK_NOTNULL(ge_tensor_desc); + auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type); + auto dims = ge_tensor_desc->GetShape().GetDims(); + + REQUIRE(index < dims.size(), "[DistributedDeploy] failed, index %ld should less than %zu", index, dims.size()); + auto dim_after_deploy = dims[index] / deploy_number; + MDS_REQUIRE_SUCCESS(ge_tensor_desc->MutableShape().SetDim(index, dim_after_deploy), + "[DistributedDeploy] update shape failed"); + return SUCCESS; +} +Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) { + GE_CHECK_NOTNULL(hcom_op); + REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld need be bigger than %ld", fission_factor, + kDefaultFissionFactor); + REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor), + "Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(), + hcom_op->GetType().c_str()); + + if (!group_name.empty()) { + REQUIRE(ge::AttrUtils::SetStr(hcom_op, HCOM_ATTR_GROUP, group_name), "Failed to set attr group %s for op:%s(%s)", + group_name.c_str(), hcom_op->GetName().c_str(), hcom_op->GetType().c_str()); + } + return SUCCESS; +} +bool MdsUtils::IsMDSNeeded() { + std::string device_type; + if (ge::GetContext().GetOption(ge::OPTION_DEVICE_TYPE, device_type) && device_type == kDefaultDeviceType) { + GELOGI("[MDS]device type is %s, skip mds", device_type.c_str()); + return false; + } + // TODO: Parse the configuration file of the system to get the sys_config_exe_unit + std::string sys_config_exe_unit = "DIE"; + return device_type != sys_config_exe_unit; +} + +Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node) { + GE_CHECK_NOTNULL(compute_graph); + GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str()); + // build deploy info + vector deploy_info; + GE_CHECK_NOTNULL(input_node); + for (int64_t j = 0; j < kDeployNumber; j++) { + int64_t device_id = j; + GeAttrValue::LIST_TENSOR graph_inputs; + GeTensorPtr graph_input = MakeShared(input_node->GetOpDesc()->GetOutputDesc(0)); + vector data{static_cast(device_id)}; + graph_input->SetData(data); + // For now, only one graph_input + graph_inputs.push_back(graph_input); + + GeAttrValue::NAMED_ATTRS thread_instance; + thread_instance.SetName(std::to_string(device_id)); + (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom(device_id)); + // TODO:Change to enumeration from RTS header file + (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom("MultiMode")); + (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom(compute_graph->GetName())); + (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom(graph_inputs)); + deploy_info.emplace_back(thread_instance); + GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); + } + // set deploy info + REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), + "Set attr failed for graph %s", compute_graph->GetName().c_str()); + + return SUCCESS; +} + +CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) { + bool is_unknown_graph = false; + if (GraphUtils::IsUnknownShapeGraph(compute_graph)) { + GELOGI("Graph %s is unknown shape graph", compute_graph->GetName().c_str()); + is_unknown_graph = true; + } + CutType selected_cut_type = kNoCut; + for (const auto &data : compute_graph->GetInputNodes()) { + GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str()); + auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN); + auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index); + auto data_h_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutH); + auto data_h_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_h_index); + if (data_n_dim == -1 && data_h_dim == -1) { + selected_cut_type = kDynamicCutAll; + break; + } + if (data_n_dim % kDeployNumber == 0) { + is_unknown_graph ? selected_cut_type = kDynamicCutN : selected_cut_type = kCutN; + break; + } + if (data_h_dim % kDeployNumber == 0) { + is_unknown_graph ? selected_cut_type = kDynamicCutH : selected_cut_type = kCutH; + } + } + return selected_cut_type; +} +Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, + const std::multimap &deploys, const std::string &device_type) { + GE_CHECK_NOTNULL(compute_graph); + GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str()); + + // build deploy info + vector deploy_info; + for (const auto &pair : deploys) { + int64_t device_id = pair.first; + GeAttrValue::NAMED_ATTRS thread_instance; + thread_instance.SetName(std::to_string(device_id)); + (void)thread_instance.SetAttr(kAttrNeedReturnResult, + GeAttrValue::CreateFrom(deploy_info.empty() ? true : false)); + (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom(device_id)); + (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom(device_type)); + (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom(compute_graph->GetName())); + (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom(pair.second)); + deploy_info.emplace_back(thread_instance); + GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); + } + // set deploy info + REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), + "Set attr failed for graph %s", compute_graph->GetName().c_str()); + + return SUCCESS; +} + +Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) { + auto src_node = src->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto dst_node = dst->GetOwnerNode(); + GE_CHECK_NOTNULL(dst_node); + auto src_graph = src_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(src_graph); + std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count)); + auto hcom_allgather_node = + AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1); + GE_CHECK_NOTNULL(hcom_allgather_node); + MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node), + "[DataGather] failed between %s and %s", src_node->GetName().c_str(), + dst_node->GetName().c_str()); + MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup), + "[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(), + hcom_allgather_node->GetType().c_str()); + REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize), + "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), + hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false), + "[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str()); + data_gather_count++; + return SUCCESS; +} + +// gradients->ApplyMomentum +// we want to reduce gradients on different device(die), so graph topo changed to +// gradients->hcomallreducemean->ApplyMomentum; Because 'mean' is not currently supported by hcomallreduce, +// topo will end up like gradients->hcomallreducesum->div->ApplyMomentum +Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) { + auto src_node = src->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto dst_node = dst->GetOwnerNode(); + GE_CHECK_NOTNULL(dst_node); + auto src_graph = src_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(src_graph); + NodePtr all_reduce_node = nullptr; + if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) { + MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node), + "[DataReduce] construct allreduce node for %s failed", all_reduce_node->GetName().c_str()); + GE_CHECK_NOTNULL(all_reduce_node); + } else { + GE_CHECK_NOTNULL(all_reduce_node); + MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(all_reduce_node->GetOpDesc(), kDeployNumber), + "[DataReduce][Modify] set attr for allreduce node for %s failed", + all_reduce_node->GetName().c_str()); + } + return SUCCESS; +} + +// tensor t with shape like [n,c,h,w], we want get [0:2/n, c, h, w] and [2/n : n, c, h, w] on different +// device; To achieve this goal, we use slice nodes. +// slice(t, [i * n/2, 0, 0, 0], [n/2, c, h, w]) i=0,1 +// slice three input like : t->slice; data(0,1)->mul(n/2)->pack[i*n/2,0,0,0]->slice; const(n,c,h,w)->slice +Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node) { + auto src_node = src->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto dst_node = dst->GetOwnerNode(); + GE_CHECK_NOTNULL(dst_node); + auto src_graph = src_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(src_graph); + if (input_node == nullptr) { + std::string input_node_name = std::string(DATA) + "_" + kPrefix + "_" + std::to_string(0); + input_node = AddSingleInputOutputNode(src_graph, input_node_name, DATA); + AddInputNode(input_node); + } + GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx()); + NodePtr slice_node = nullptr; + MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node), + "[DataSlice] construct slice node for %s failed", src_node->GetName().c_str()); + GE_CHECK_NOTNULL(slice_node); + MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s", + src_node->GetName().c_str(), dst_node->GetName().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed", + slice_node->GetName().c_str()); + return SUCCESS; +} + +Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node, + NodePtr &slice_node) { + vector slice_sizes = tensor.GetShape().GetDims(); + // TODO: Express with graph structure + slice_sizes[0] /= kDeployNumber; + vector ge_tensors; + GeTensorDesc ge_tensor_desc; + ge_tensor_desc.SetDataType(DT_INT64); + MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors), + "[ConstructTensorDescWithData] failed"); + GeTensorPtr slice_size_tensor = ge_tensors[0]; + auto const_node_slice_size = AddConstNodeToGraph(slice_size_tensor, src_graph); + + vector slice_offset_other_dim{0}; + ge_tensors.clear(); + MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_other_dim, ge_tensors, true), + "[ConstructTensorDescWithData] failed"); + GeTensorPtr slice_offset_tensor = ge_tensors[0]; + auto const_node_slice_offset = AddConstNodeToGraph(slice_offset_tensor, src_graph); + + vector slice_offset_first_dim{slice_sizes[0]}; + ge_tensors.clear(); + MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_first_dim, ge_tensors, true), + "[ConstructTensorDescWithData] failed"); + GeTensorPtr slice_offset_first_dim_tensor = ge_tensors[0]; + auto const_node_slice_offset_first_dim = AddConstNodeToGraph(slice_offset_first_dim_tensor, src_graph); + std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_slice_count)); + NodePtr mul_node = AddDynamicInputOutputNode(src_graph, MUL, MUL + node_name_suffix, 2, 1); + GE_CHECK_NOTNULL(input_node); + MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)), + "[ConstructSliceNode] add edge failed"); + MDS_REQUIRE_SUCCESS( + GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)), + "[ConstructSliceNode] add edge failed"); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed", + mul_node->GetName().c_str()); + NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1); + bool is_first_input = true; + for (const auto &in_anchor : pack_node->GetAllInDataAnchors()) { + if (is_first_input) { + MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(mul_node->GetOutDataAnchor(0), in_anchor), + "[ConstructSliceNode] add edge failed"); + is_first_input = false; + } else { + MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset->GetOutDataAnchor(0), in_anchor), + "[ConstructSliceNode] add edge failed"); + } + } + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed", + pack_node->GetName().c_str()); + slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1); + MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)), + "[ConstructSliceNode] add edge failed"); + MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_size->GetOutDataAnchor(0), slice_node->GetInDataAnchor(2)), + "[ConstructSliceNode] add edge failed"); + ++data_slice_count; + return SUCCESS; +} + +NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, + const GeTensorDesc &tensor) { + GELOGI("Begin to create op: %s", name.c_str()); + OpDescBuilder op_desc_builder(name, type); + OpDescPtr op_desc = op_desc_builder.AddInput("x", tensor).AddOutput("y", tensor).Build(); + if (op_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", name.c_str(), type.c_str()); + GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", name.c_str(), type.c_str()); + return nullptr; + } + NodePtr node = graph->AddNode(op_desc); + if (node == nullptr) { + REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), graph->GetName().c_str()); + GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), + graph->GetName().c_str()); + return nullptr; + } + return node; +} + +NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type, + const std::string &node_name, size_t input_num, size_t output_num) { + GELOGI("Begin to create op: %s", node_name.c_str()); + OpDescBuilder op_desc_builder(node_name, type); + OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build(); + if (op_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", node_name.c_str(), type.c_str()); + GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", node_name.c_str(), type.c_str()); + return nullptr; + } + NodePtr node = graph->AddNode(op_desc); + if (node == nullptr) { + REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), graph->GetName().c_str()); + GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), + graph->GetName().c_str()); + return nullptr; + } + return node; +} + +NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph) { + auto const_desc = OpDescUtils::CreateConstOp(tensor); + if (const_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "Create Const op failed"); + GELOGE(OUT_OF_MEMORY, "[Create][ConstOp] failed"); + return nullptr; + } + + if (graph == nullptr) { + GELOGW("input param graph is null"); + return nullptr; + } + return graph->AddNodeFront(const_desc); +} + +Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst, NodePtr &reduce_node) { + std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count)); + reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1); + + MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node), + "[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(), + src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); + MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup), + "[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str()); + REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction), + "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), + reduce_node->GetName().c_str(), reduce_node->GetType().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed", + reduce_node->GetName().c_str()); + auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1); + vector slice_sizes{kDeployNumber}; + vector ge_tensors; + GeTensorDesc ge_tensor_desc; + ge_tensor_desc.SetDataType(DT_INT64); + MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors), + "[ConstructReduceNode] failed"); + REQUIRE(!ge_tensors.empty(), "[ConstructReduceNode] failed"); + auto const_node_div_input = AddConstNodeToGraph(ge_tensors[0], src_graph); + MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)), + "[ConstructSliceNode] add edge failed"); + MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node), + "[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(), + reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); + MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed", + div_node->GetName().c_str()); + return SUCCESS; +} + +bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) { + // TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node, + // so there is no need to insert i + return true; +} + +} // namespace ge diff --git a/ge/graph/passes/mds_kernels/mds_utils.h b/ge/graph/passes/mds_kernels/mds_utils.h new file mode 100644 index 00000000..f5199b14 --- /dev/null +++ b/ge/graph/passes/mds_kernels/mds_utils.h @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ +#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ + +#include "graph/ge_context.h" +#include "common/op/ge_op_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "ge/ge_api_types.h" +#include "common/ge/ge_util.h" +#include "graph/compute_graph.h" +#include "graph/shape_refiner.h" +#include "graph/debug/ge_op_types.h" +#include "framework/common/types.h" +#include "graph/utils/op_desc_utils.h" +#include "../pass_utils.h" + +#define REQUIRE(cond, ...) \ + do { \ + if (!(cond)) { \ + REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ + GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ + return FAILED; \ + } \ + } while (0) + +#define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__) +#define MDS_REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) +#define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) +namespace ge { +namespace { +// Invalid location index +const int64_t kInvalidIndex = -1; +enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 }; +enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 }; + +// NCHW dim N index +const int32_t kNchwDimIdxN = 0; +// NCHW dim C index +const int32_t kNchwDimIdxC = 1; +// NCHW dim H index +const int32_t kNchwDimIdxH = 2; +// NCHW dim W index +const int32_t kNchwDimIdxW = 3; +// default die number +const uint32_t kDeployNumber = 2; +enum CutType { kNoCut = 0, kCutN, kCutH, kDynamicCutN, kDynamicCutH, kDynamicCutAll }; +enum TensorCutInfo { kNotSupport = 0, kSplitCutSupported, kAnyCutSupported = 3 }; + +const int64_t kDefaultFissionFactor = 1; +const int64_t kDefaultRankSize = 1; +const std::string kDefaultGroup = "hccl_world_group"; +const std::string kDefaultReduction = "sum"; +const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE"; +const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE"; + +// deploy info +const char *const kAttrNeedReturnResult = "_need_return_result"; +const char *const kAttrDeviceType = "_device_type"; +const char *const kDieDeviceTypeValue = "MultiMode"; +const char *const kAttrDeviceId = "_device_id"; +const char *const kAttrGraphName = "_graph_name"; +const char *const kAttrGraphInputs = "_graph_inputs"; +using GraphInputs = vector; +using DeviceId = int64_t; +using GraphInputNodes = vector; +} // namespace +class MdsUtils { + public: + // Parse the configuration file and determine whether to enable MDS based on the value of device_type. + static bool IsMDSNeeded(); + static int64_t GetNLocation(Format fmt); + static int64_t GetHLocation(Format fmt); + + static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type); + static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type); + static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, + const std::string &group_name = ""); + /// @param [in] index 切分的轴 + /// @param [in] deploy_number 切分的份数 + static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, + int64_t deploy_number = kDeployNumber); + // Sets the information, notifies the number of threads to be started during the + // loading phase, the device on which each thread should run, and constructs different input data on each device. + static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node); + static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap &deploys, + const std::string &device_type = kDieDeviceTypeValue); + // Get cut policy for whole graph + static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph); + static GraphInputNodes GetInputNodes() { + return input_nodes_; + } + static void AddInputNode(const NodePtr &input_node) { + input_nodes_.push_back(input_node); + } + + static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + static Status DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node); + + private: + static GraphInputNodes input_nodes_; + static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name, + size_t input_num, size_t output_num); + static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, + const GeTensorDesc &tensor = GeTensorDesc()); + static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst, NodePtr &reduce_node); + static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node, + NodePtr &slice_node); + static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node); + static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph); +}; +} // namespace ge +#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ diff --git a/ge/graph/passes/mds_kernels/variable_mds_kernel.cc b/ge/graph/passes/mds_kernels/variable_mds_kernel.cc new file mode 100644 index 00000000..a84a7ca3 --- /dev/null +++ b/ge/graph/passes/mds_kernels/variable_mds_kernel.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "variable_mds_kernel.h" +#include "mds_kernel_factory.h" +namespace ge { + + +Status VariableDeploySchedulerKernel::CutN(const ge::NodePtr& node_ptr) { + GE_CHECK_NOTNULL(node_ptr); + if (MdsUtils::IsDistributedDeploySupported(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutN)) { + return MdsUtils::DistributedDeploy(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutN); + } + return SUCCESS; + +} +Status VariableDeploySchedulerKernel::CutH(const ge::NodePtr& node_ptr) { + GE_CHECK_NOTNULL(node_ptr); + if (MdsUtils::IsDistributedDeploySupported(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutH)) { + return MdsUtils::DistributedDeploy(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutH); + } + return SUCCESS; +} +REGISTER_MDS_KERNEL(VARIABLE, VariableDeploySchedulerKernel); + +} + + diff --git a/ge/graph/passes/mds_kernels/variable_mds_kernel.h b/ge/graph/passes/mds_kernels/variable_mds_kernel.h new file mode 100644 index 00000000..c2307450 --- /dev/null +++ b/ge/graph/passes/mds_kernels/variable_mds_kernel.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_ +#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_ + +#include "base_mds_kernel.h" + +namespace ge { +class VariableDeploySchedulerKernel : public DeploySchedulerKernel { + public: + Status CutN(const ge::NodePtr& node_ptr) override; + Status CutH(const ge::NodePtr& node_ptr) override; +}; +} // namespace ge +#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_ diff --git a/ge/graph/passes/mds_pass.cc b/ge/graph/passes/mds_pass.cc new file mode 100644 index 00000000..aa411ed5 --- /dev/null +++ b/ge/graph/passes/mds_pass.cc @@ -0,0 +1,177 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "./mds_pass.h" + +namespace ge { +Status ModelDeploySchedulerPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + compute_graph_ = graph; + if (!MdsUtils::IsMDSNeeded()) { + return SUCCESS; + } + GELOGI("[MDS][%s] start to deploy.", GetGraphName()); + MDS_REQUIRE_SUCCESS(SMDPProcess(), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName()); + MDS_REQUIRE_SUCCESS(CutProcess(), "[MDS][CutProcess] failed, graph_name:[%s]", GetGraphName()); + MDS_REQUIRE_SUCCESS(SMDPProcess(false), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName()); + MDS_REQUIRE_SUCCESS(SwapProcess(), "[MDS][SwapProcess] failed, graph_name:[%s]", GetGraphName()); + MDS_REQUIRE_SUCCESS(PiplineProcess(), "[MDS][PiplineProcess] failed, graph_name:[%s]", GetGraphName()); + MDS_REQUIRE_SUCCESS(SetDeployInfo(), "[MDS][SetDeployInfo] failed, graph_name:[%s]", GetGraphName()); + + GELOGI("[MDS][%s] deploy successfully.", graph->GetName().c_str()); + return SUCCESS; +} + +Status ModelDeploySchedulerPass::CutProcess() { + GE_CHECK_NOTNULL(compute_graph_); + if (!compute_graph_->GetAllSubgraphs().empty() || compute_graph_->GetParentGraph() != nullptr) { + GELOGW("[MDS][CutProcess] graph with subgraphs is not supported now. graph_name:[%s]", GetGraphName()); + return SUCCESS; + } + auto type = MdsUtils::TryGetGraphCutType(compute_graph_); + switch (type) { + case kCutN: + MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]", + GetGraphName()); + break; + case kCutH: + MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]", + GetGraphName()); + break; + case kDynamicCutN: + MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", + GetGraphName()); + break; + case kDynamicCutH: + MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", + GetGraphName()); + break; + case kDynamicCutAll: + MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]", + GetGraphName()); + break; + default: + GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName()); + return SUCCESS; + } +} + +Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) { + GE_CHECK_NOTNULL(compute_graph); + // step 0: Cut + for (const auto &node : compute_graph->GetDirectNode()) { + auto op_kernel = mds_cut_pass::GetKernelByType(node); + if (op_kernel == nullptr) { + op_kernel = DeploySchedulerKernel::Instance(); + } + if (is_dynamic) { + MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]", + node->GetName().c_str()); + } else { + MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str()); + } + bool is_grad_compute_node = false; + if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && + is_grad_compute_node) { + grad_compute_nodes_.push_back(node); + } + } + // TODO:for single output multi reference insertion allgather, allreduce nodes, do breadth fusion optimization + MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]", + compute_graph->GetName().c_str()); + return SUCCESS; +} + +Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) { + GE_CHECK_NOTNULL(compute_graph); + for (NodePtr &node : compute_graph->GetDirectNode()) { + auto op_kernel = mds_cut_pass::GetKernelByType(node); + if (op_kernel == nullptr) { + op_kernel = DeploySchedulerKernel::Instance(); + } + if (is_dynamic) { + MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]", + node->GetName().c_str()); + } else { + MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str()); + } + } + return SUCCESS; +} + +Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_graph) { + std::vector input_nodes; + std::vector output_nodes; + auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); + auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); + MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", + compute_graph0->GetName().c_str()); + MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", + compute_graph1->GetName().c_str()); + // TODO:Create a case node, put the two graphs under the two branches of case + return SUCCESS; +} + +Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) { + if (before_cut) { + MDS_REQUIRE_SUCCESS(SMDPModelState(), "[SMDPProcess][SMDPModelState] failed, graph_name:[%s]", GetGraphName()); + MDS_REQUIRE_SUCCESS(SMDPWeight(), "[SMDPProcess][SMDPWeight] failed, graph_name:[%s]", GetGraphName()); + } else { + MDS_REQUIRE_SUCCESS(SMDPGradient(), "[SMDPProcess][SMDPGradient] failed, graph_name:[%s]", GetGraphName()); + } + return SUCCESS; +} + +Status ModelDeploySchedulerPass::SetDeployInfo() { + vector deployInfo; + REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo), + "%s already has deployed before!", GetGraphName()); + std::multimap deploys; + for (int64_t j = 0; j < kDeployNumber; j++) { + int64_t device_id = j; + GraphInputs graph_inputs; + // For now, only one input_node in input_nodes + for (const auto &input_node : MdsUtils::GetInputNodes()) { + GE_CHECK_NOTNULL(input_node); + GeTensorPtr graph_input = MakeShared(input_node->GetOpDesc()->GetOutputDesc(0)); + vector data{static_cast(device_id)}; + graph_input->SetData(data); + graph_inputs.push_back(graph_input); + } + deploys.emplace(j, graph_inputs); + } + return MdsUtils::SetDeployInfo(compute_graph_, deploys); +} + +Status ModelDeploySchedulerPass::SwapProcess() { + return SUCCESS; +} +Status ModelDeploySchedulerPass::PiplineProcess() { + return SUCCESS; +} +Status ModelDeploySchedulerPass::HcomNodeFusionProcess() { + return SUCCESS; +} +Status ModelDeploySchedulerPass::SMDPModelState() { + return SUCCESS; +} +Status ModelDeploySchedulerPass::SMDPWeight() { + return SUCCESS; +} +Status ModelDeploySchedulerPass::SMDPGradient() { + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/mds_pass.h b/ge/graph/passes/mds_pass.h new file mode 100644 index 00000000..4a18cc1e --- /dev/null +++ b/ge/graph/passes/mds_pass.h @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ +#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ + +#include "graph/types.h" +#include "ge/ge_api.h" +#include "graph/debug/ge_attr_define.h" +#include "inc/graph_pass.h" +#include "./mds_kernels/base_mds_kernel.h" +#include "ge/ge_api_types.h" +#include "./mds_kernels/mds_utils.h" + +namespace ge { +class ModelDeploySchedulerPass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph) override; + + private: + // Part0:Process Func + // cut and dynamic cut + Status CutProcess(); + Status CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic = false); + Status CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic = false); + Status DynamicCutAll(const ComputeGraphPtr &compute_graph); + // smdp + Status SMDPProcess(bool before_cut = true); + Status SMDPModelState(); + Status SMDPGradient(); + Status SMDPWeight(); + // swap + Status SwapProcess(); + // pipline + Status PiplineProcess(); + // set delpoyinfo + Status SetDeployInfo(); + + // Part1: Utils Func + // std::vector GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index); + // std::vector GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index); + Status HcomNodeFusionProcess(); + Status GetAllModelStateVar(); + Status GetAllWeightVar(); + std::vector GetAllGradComputeNodes() { + return grad_compute_nodes_; + } + const char *GetGraphName() const { + return compute_graph_->GetName().c_str(); + } + + // members + std::vector model_state_vars_; + std::vector model_weight_vars_; + std::vector grad_compute_nodes_; + ComputeGraphPtr compute_graph_ = nullptr; +}; +} // namespace ge +#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 6f5bbfbf..d9d6dc19 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -28,6 +28,7 @@ namespace ge { // Option key: graph run mode const char *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; +const char *const OPTION_DEVICE_TYPE = "ge.deviceType"; // Option key: ome init const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId";